From 1f3e1131f8ecd71e122ad55ee00999c2a9f9023e Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 2 Dec 2025 18:43:26 +0100 Subject: [PATCH 01/15] add additional in-list tests: --- .../physical-expr/src/expressions/in_list.rs | 606 ++++++++++++++++++ 1 file changed, 606 insertions(+) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 81c2bd17a8d6..c5d6fef946e6 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -1024,6 +1024,612 @@ mod tests { Ok(()) } + #[test] + fn in_list_int8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int8, true)]); + let a = Int8Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0i8), lit(1i8)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0i8), lit(1i8)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0i8), lit(1i8), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0i8), lit(1i8), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_int16() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int16, true)]); + let a = Int16Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = 
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0i16), lit(1i16)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0i16), lit(1i16)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0i16), lit(1i16), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0i16), lit(1i16), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_int32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0i32), lit(1i32)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0i32), lit(1i32)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0i32), lit(1i32), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0i32), lit(1i32), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + 
+ Ok(()) + } + + #[test] + fn in_list_uint8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt8, true)]); + let a = UInt8Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0u8), lit(1u8)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0u8), lit(1u8)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0u8), lit(1u8), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0u8), lit(1u8), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_uint16() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt16, true)]); + let a = UInt16Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0u16), lit(1u16)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0u16), lit(1u16)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0u16), lit(1u16), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + 
Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0u16), lit(1u16), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_uint32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt32, true)]); + let a = UInt32Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0u32), lit(1u32)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0u32), lit(1u32)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0u32), lit(1u32), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0u32), lit(1u32), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_uint64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt64, true)]); + let a = UInt64Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0u64), lit(1u64)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0u64), lit(1u64)]; + in_list!( + batch, + list, + &true, + 
vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0u64), lit(1u64), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0u64), lit(1u64), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_large_utf8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::LargeUtf8, true)]); + let a = LargeStringArray::from(vec![Some("a"), Some("d"), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in ("a", "b")" + let list = vec![lit("a"), lit("b")]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ("a", "b")" + let list = vec![lit("a"), lit("b")]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in ("a", "b", null)" + let list = vec![lit("a"), lit("b"), lit(ScalarValue::LargeUtf8(None))]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ("a", "b", null)" + let list = vec![lit("a"), lit("b"), lit(ScalarValue::LargeUtf8(None))]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_utf8_view() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8View, true)]); + let a = StringViewArray::from(vec![Some("a"), Some("d"), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; 
+ + // expression: "a in ("a", "b")" + let list = vec![lit("a"), lit("b")]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ("a", "b")" + let list = vec![lit("a"), lit("b")]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in ("a", "b", null)" + let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8View(None))]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ("a", "b", null)" + let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8View(None))]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_large_binary() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::LargeBinary, true)]); + let a = LargeBinaryArray::from(vec![ + Some([1, 2, 3].as_slice()), + Some([1, 2, 2].as_slice()), + None, + ]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in ([1, 2, 3], [4, 5, 6])" + let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ([1, 2, 3], [4, 5, 6])" + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in ([1, 2, 3], [4, 5, 6], null)" + let list = vec![ + lit([1, 2, 3].as_slice()), + lit([4, 5, 6].as_slice()), + lit(ScalarValue::LargeBinary(None)), + ]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ([1, 2, 3], [4, 5, 6], null)" + in_list!( + 
batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_binary_view() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::BinaryView, true)]); + let a = BinaryViewArray::from(vec![ + Some([1, 2, 3].as_slice()), + Some([1, 2, 2].as_slice()), + None, + ]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in ([1, 2, 3], [4, 5, 6])" + let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ([1, 2, 3], [4, 5, 6])" + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in ([1, 2, 3], [4, 5, 6], null)" + let list = vec![ + lit([1, 2, 3].as_slice()), + lit([4, 5, 6].as_slice()), + lit(ScalarValue::BinaryView(None)), + ]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ([1, 2, 3], [4, 5, 6], null)" + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + #[test] fn in_list_date64() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Date64, true)]); From 15f41939c7f4f8bc0c626432c072a887ef6329c5 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 2 Dec 2025 19:06:55 +0100 Subject: [PATCH 02/15] refactor and show bugs --- .../physical-expr/src/expressions/in_list.rs | 109 ++++++++++++++---- 1 file changed, 86 insertions(+), 23 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index c5d6fef946e6..69b9dc135020 100644 --- 
a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -355,6 +355,38 @@ impl InListExpr { Some(instantiate_static_filter(array)?), )) } + + /// Create a new InList expression with a static filter for constant list expressions. + /// + /// This validates data types, evaluates the list as constants, and uses specialized + /// StaticFilter implementations for better performance (e.g., Int32StaticFilter for Int32). + /// + /// Returns an error if data types don't match or if the list contains non-constant expressions. + pub fn try_from_static_filter( + expr: Arc, + list: Vec>, + negated: bool, + schema: &Schema, + ) -> Result { + // Check the data types match + let expr_data_type = expr.data_type(schema)?; + for list_expr in list.iter() { + let list_expr_data_type = list_expr.data_type(schema)?; + assert_or_internal_err!( + DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type), + "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" + ); + } + + // Evaluate the list as constants and create the static filter + let in_array = try_evaluate_constant_list(&list, schema)?; + Ok(Self::new( + expr, + list, + negated, + Some(instantiate_static_filter(in_array)?), + )) + } } impl std::fmt::Display for InListExpr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { @@ -556,30 +588,28 @@ pub fn in_list( negated: &bool, schema: &Schema, ) -> Result> { - // check the data type - let expr_data_type = expr.data_type(schema)?; - for list_expr in list.iter() { - let list_expr_data_type = list_expr.data_type(schema)?; - assert_or_internal_err!( - DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type), - "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" - ); - } - - // Try to create a static filter for constant expressions - let 
static_filter = try_evaluate_constant_list(&list, schema) - .and_then(ArrayStaticFilter::try_new) - .ok() - .map(|static_filter| { - Arc::new(static_filter) as Arc - }); - - Ok(Arc::new(InListExpr::new( - expr, - list, + // Try to create with static filter (validates types and evaluates constant list) + match InListExpr::try_from_static_filter( + Arc::clone(&expr), + list.clone(), *negated, - static_filter, - ))) + schema, + ) { + Ok(expr) => Ok(Arc::new(expr)), + Err(_) => { + // Fall back to non-static filter if list contains non-constant expressions + // Still need to validate types + let expr_data_type = expr.data_type(schema)?; + for list_expr in list.iter() { + let list_expr_data_type = list_expr.data_type(schema)?; + assert_or_internal_err!( + DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type), + "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" + ); + } + Ok(Arc::new(InListExpr::new(expr, list, *negated, None))) + } + } } #[cfg(test)] @@ -3421,4 +3451,37 @@ mod tests { Ok(()) } + + #[test] + fn test_in_list_dictionary_int32() -> Result<()> { + // Test that Int32StaticFilter handles dictionary-encoded Int32 columns. + // This exposes a bug where as_primitive_opt::() returns None + // for dictionary arrays, causing an error. 
+ + // Create schema with dictionary-encoded Int32 column + let dict_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)); + let schema = Schema::new(vec![Field::new("a", dict_type.clone(), false)]); + let col_a = col("a", &schema)?; + + // Create IN list with Int32 literals: (1, 2, 3) + let list = vec![lit(1i32), lit(2i32), lit(3i32)]; + + // Create InListExpr via in_list() - this uses Int32StaticFilter for Int32 lists + let expr = in_list(col_a, list, &false, &schema)?; + + // Create dictionary-encoded batch with values [1, 2, 5] + // Dictionary: keys [0, 1, 2] -> values [1, 2, 5] + let keys = Int8Array::from(vec![0, 1, 2]); + let values = Int32Array::from(vec![1, 2, 5]); + let dict_array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, Arc::new(values))?); + let batch = RecordBatch::try_new(Arc::new(schema), vec![dict_array])?; + + // Expected: [1 IN (1,2,3), 2 IN (1,2,3), 5 IN (1,2,3)] = [true, true, false] + let result = expr.evaluate(&batch)?.into_array(3)?; + let result = as_boolean_array(&result); + assert_eq!(result, &BooleanArray::from(vec![true, true, false])); + Ok(()) + } } From b4b23ed174e9b598e187cea286c3546474fa4d95 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 2 Dec 2025 19:11:05 +0100 Subject: [PATCH 03/15] refactor --- .../physical-expr/src/expressions/in_list.rs | 49 +++++++++---------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 69b9dc135020..91f190e9633c 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -378,14 +378,27 @@ impl InListExpr { ); } - // Evaluate the list as constants and create the static filter - let in_array = try_evaluate_constant_list(&list, schema)?; - Ok(Self::new( - expr, - list, - negated, - 
Some(instantiate_static_filter(in_array)?), - )) + match try_evaluate_constant_list(&list, schema) { + Ok(in_array) => Ok(Self::new( + expr, + list, + negated, + Some(instantiate_static_filter(in_array)?), + )), + Err(_) => { + // Fall back to non-static filter if list contains non-constant expressions + // Still need to validate types + let expr_data_type = expr.data_type(schema)?; + for list_expr in list.iter() { + let list_expr_data_type = list_expr.data_type(schema)?; + assert_or_internal_err!( + DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type), + "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" + ); + } + Ok(Self::new(expr, list, negated, None)) + } + } } } impl std::fmt::Display for InListExpr { @@ -588,28 +601,12 @@ pub fn in_list( negated: &bool, schema: &Schema, ) -> Result> { - // Try to create with static filter (validates types and evaluates constant list) - match InListExpr::try_from_static_filter( + Ok(Arc::new(InListExpr::try_from_static_filter( Arc::clone(&expr), list.clone(), *negated, schema, - ) { - Ok(expr) => Ok(Arc::new(expr)), - Err(_) => { - // Fall back to non-static filter if list contains non-constant expressions - // Still need to validate types - let expr_data_type = expr.data_type(schema)?; - for list_expr in list.iter() { - let list_expr_data_type = list_expr.data_type(schema)?; - assert_or_internal_err!( - DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type), - "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" - ); - } - Ok(Arc::new(InListExpr::new(expr, list, *negated, None))) - } - } + )?)) } #[cfg(test)] From c0ce5aca3bf0d107f165ca9e3df82b465b48399b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 2 Dec 2025 22:55:35 +0100 Subject: [PATCH 04/15] fixes --- 
.../physical-expr/src/expressions/in_list.rs | 60 ++++++++++++++----- 1 file changed, 45 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 91f190e9633c..b343c5d31d5d 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -226,31 +226,61 @@ impl StaticFilter for Int32StaticFilter { } fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Handle dictionary arrays by recursing on the values + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + let v = v .as_primitive_opt::() .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; - let result = match (v.null_count() > 0, negated) { - (true, false) => { - // has nulls, not negated" - BooleanArray::from_iter( - v.iter().map(|value| Some(self.values.contains(&value?))), - ) + let haystack_has_nulls = self.null_count > 0; + + let result = match (v.null_count() > 0, haystack_has_nulls, negated) { + (true, _, false) | (false, true, false) => { + // Either needle or haystack has nulls, not negated + BooleanArray::from_iter(v.iter().map(|value| match value { + None => None, + Some(v) => { + if self.values.contains(&v) { + Some(true) + } else if haystack_has_nulls { + None + } else { + Some(false) + } + } + })) } - (true, true) => { - // has nulls, negated - BooleanArray::from_iter( - v.iter().map(|value| Some(!self.values.contains(&value?))), - ) + (true, _, true) | (false, true, true) => { + // Either needle or haystack has nulls, negated + BooleanArray::from_iter(v.iter().map(|value| match value { + None => None, + Some(v) => { + if self.values.contains(&v) { + Some(false) + } else if haystack_has_nulls { + None + } else { + Some(true) + } + } + })) } - 
(false, false) => { - //no null, not negated + (false, false, false) => { + // No nulls anywhere, not negated BooleanArray::from_iter( v.values().iter().map(|value| self.values.contains(value)), ) } - (false, true) => { - // no null, negated + (false, false, true) => { + // No nulls anywhere, negated BooleanArray::from_iter( v.values().iter().map(|value| !self.values.contains(value)), ) From e381c9e7d8c49cd37ca41e3c69d9966b352109f9 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 3 Dec 2025 18:12:15 +0100 Subject: [PATCH 05/15] remove comment --- datafusion/physical-expr/src/expressions/in_list.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index b343c5d31d5d..65ed1cfcaf2f 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -3481,10 +3481,6 @@ mod tests { #[test] fn test_in_list_dictionary_int32() -> Result<()> { - // Test that Int32StaticFilter handles dictionary-encoded Int32 columns. - // This exposes a bug where as_primitive_opt::() returns None - // for dictionary arrays, causing an error. 
- // Create schema with dictionary-encoded Int32 column let dict_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)); From 5d632474f8fda239a45dfe430c90da8234b85859 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 3 Dec 2025 18:35:20 +0100 Subject: [PATCH 06/15] lint --- datafusion/physical-expr/src/expressions/in_list.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 65ed1cfcaf2f..a4c68f5283ec 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -632,10 +632,7 @@ pub fn in_list( schema: &Schema, ) -> Result> { Ok(Arc::new(InListExpr::try_from_static_filter( - Arc::clone(&expr), - list.clone(), - *negated, - schema, + expr, list, *negated, schema, )?)) } From b89ed19ab0b5db3980799b552bd849aa76f0c117 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 8 Dec 2025 11:54:22 -0600 Subject: [PATCH 07/15] Add a test harness for primitive types --- .../physical-expr/src/expressions/in_list.rs | 1330 +++-------------- 1 file changed, 213 insertions(+), 1117 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index a4c68f5283ec..f78abd5bdd03 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -403,7 +403,10 @@ impl InListExpr { for list_expr in list.iter() { let list_expr_data_type = list_expr.data_type(schema)?; assert_or_internal_err!( - DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type), + DFSchema::datatype_is_logically_equal( + &expr_data_type, + &list_expr_data_type + ), "The data type inlist should be same, the value type is {expr_data_type}, 
one of list expr type is {list_expr_data_type}" ); } @@ -422,7 +425,10 @@ impl InListExpr { for list_expr in list.iter() { let list_expr_data_type = list_expr.data_type(schema)?; assert_or_internal_err!( - DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type), + DFSchema::datatype_is_logically_equal( + &expr_data_type, + &list_expr_data_type + ), "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" ); } @@ -639,7 +645,6 @@ pub fn in_list( #[cfg(test)] mod tests { use super::*; - use crate::expressions; use crate::expressions::{col, lit, try_cast}; use arrow::buffer::NullBuffer; use datafusion_common::plan_err; @@ -752,15 +757,48 @@ mod tests { }}; } - #[test] - fn in_list_utf8() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - let a = StringArray::from(vec![Some("a"), Some("d"), None]); + /// Test case for primitive types following the standard IN LIST pattern. + /// + /// Each test case represents a data type with: + /// - `value_in`: A value that appears in both the test array and the IN list (matches → true) + /// - `value_not_in`: A value that appears in the test array but NOT in the IN list (doesn't match → false) + /// - `value_in_list`: A value that appears in the IN list but not in the array (filler value) + /// - `null_value`: A null scalar value for NULL handling tests + struct InListPrimitiveTestCase { + name: &'static str, + value_in: ScalarValue, + value_not_in: ScalarValue, + value_in_list: ScalarValue, + null_value: ScalarValue, + } + + /// Runs the standard 4 IN LIST test scenarios for a primitive type. + /// + /// Creates a test array with [Some(value_in), Some(value_not_in), None] and tests: + /// 1. `a IN (value_in, value_in_list)` → `[true, false, null]` + /// 2. `a NOT IN (value_in, value_in_list)` → `[false, true, null]` + /// 3. `a IN (value_in, value_in_list, NULL)` → `[true, null, null]` + /// 4. 
`a NOT IN (value_in, value_in_list, NULL)` → `[false, null, null]` + fn run_primitive_in_list_test(test_case: InListPrimitiveTestCase) -> Result<()> { + // Get the data type from the scalar value + let data_type = test_case.value_in.data_type(); + let schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); + + // Create array from scalar values: [value_in, value_not_in, None] + let array = ScalarValue::iter_to_array(vec![ + test_case.value_in.clone(), + test_case.value_not_in.clone(), + test_case.null_value.clone(), + ])?; + let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![array])?; - // expression: "a in ("a", "b")" - let list = vec![lit("a"), lit("b")]; + // Test 1: a IN (value_in, value_in_list) + let list = vec![ + lit(test_case.value_in.clone()), + lit(test_case.value_in_list.clone()), + ]; in_list!( batch, list, @@ -770,8 +808,11 @@ mod tests { &schema ); - // expression: "a not in ("a", "b")" - let list = vec![lit("a"), lit("b")]; + // Test 2: a NOT IN (value_in, value_in_list) + let list = vec![ + lit(test_case.value_in.clone()), + lit(test_case.value_in_list.clone()), + ]; in_list!( batch, list, @@ -781,8 +822,12 @@ mod tests { &schema ); - // expression: "a in ("a", "b", null)" - let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; + // Test 3: a IN (value_in, value_in_list, NULL) + let list = vec![ + lit(test_case.value_in.clone()), + lit(test_case.value_in_list.clone()), + lit(test_case.null_value.clone()), + ]; in_list!( batch, list, @@ -792,68 +837,12 @@ mod tests { &schema ); - // expression: "a not in ("a", "b", null)" - let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_binary() -> Result<()> { - let schema = 
Schema::new(vec![Field::new("a", DataType::Binary, true)]); - let a = BinaryArray::from(vec![ - Some([1, 2, 3].as_slice()), - Some([1, 2, 2].as_slice()), - None, - ]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in ([1, 2, 3], [4, 5, 6])" - let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())]; - in_list!( - batch, - list.clone(), - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in ([1, 2, 3], [4, 5, 6])" - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in ([1, 2, 3], [4, 5, 6], null)" + // Test 4: a NOT IN (value_in, value_in_list, NULL) let list = vec![ - lit([1, 2, 3].as_slice()), - lit([4, 5, 6].as_slice()), - lit(ScalarValue::Binary(None)), + lit(test_case.value_in), + lit(test_case.value_in_list), + lit(test_case.null_value), ]; - in_list!( - batch, - list.clone(), - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in ([1, 2, 3], [4, 5, 6], null)" in_list!( batch, list, @@ -866,56 +855,158 @@ mod tests { Ok(()) } + /// Consolidated test for all primitive types following the standard IN LIST pattern. + /// + /// This test replaces individual test functions for: Int8/16/32/64, UInt8/16/32/64, + /// Utf8, LargeUtf8, Utf8View, Binary, LargeBinary, BinaryView, Date32, Date64, + /// Decimal, and Timestamp types. 
#[test] - fn in_list_int64() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); - let a = Int64Array::from(vec![Some(0), Some(2), None]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in (0, 1)" - let list = vec![lit(0i64), lit(1i64)]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1)" - let list = vec![lit(0i64), lit(1i64)]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in (0, 1, NULL)" - let list = vec![lit(0i64), lit(1i64), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); + fn in_list_primitive_types() -> Result<()> { + let test_cases = vec![ + // Signed integers + InListPrimitiveTestCase { + name: "int8", + value_in: ScalarValue::Int8(Some(0)), + value_not_in: ScalarValue::Int8(Some(2)), + value_in_list: ScalarValue::Int8(Some(1)), + null_value: ScalarValue::Int8(None), + }, + InListPrimitiveTestCase { + name: "int16", + value_in: ScalarValue::Int16(Some(0)), + value_not_in: ScalarValue::Int16(Some(2)), + value_in_list: ScalarValue::Int16(Some(1)), + null_value: ScalarValue::Int16(None), + }, + InListPrimitiveTestCase { + name: "int32", + value_in: ScalarValue::Int32(Some(0)), + value_not_in: ScalarValue::Int32(Some(2)), + value_in_list: ScalarValue::Int32(Some(1)), + null_value: ScalarValue::Int32(None), + }, + InListPrimitiveTestCase { + name: "int64", + value_in: ScalarValue::Int64(Some(0)), + value_not_in: ScalarValue::Int64(Some(2)), + value_in_list: ScalarValue::Int64(Some(1)), + null_value: ScalarValue::Int64(None), + }, + // Unsigned integers + InListPrimitiveTestCase { + name: "uint8", + value_in: ScalarValue::UInt8(Some(0)), + value_not_in: 
ScalarValue::UInt8(Some(2)), + value_in_list: ScalarValue::UInt8(Some(1)), + null_value: ScalarValue::UInt8(None), + }, + InListPrimitiveTestCase { + name: "uint16", + value_in: ScalarValue::UInt16(Some(0)), + value_not_in: ScalarValue::UInt16(Some(2)), + value_in_list: ScalarValue::UInt16(Some(1)), + null_value: ScalarValue::UInt16(None), + }, + InListPrimitiveTestCase { + name: "uint32", + value_in: ScalarValue::UInt32(Some(0)), + value_not_in: ScalarValue::UInt32(Some(2)), + value_in_list: ScalarValue::UInt32(Some(1)), + null_value: ScalarValue::UInt32(None), + }, + InListPrimitiveTestCase { + name: "uint64", + value_in: ScalarValue::UInt64(Some(0)), + value_not_in: ScalarValue::UInt64(Some(2)), + value_in_list: ScalarValue::UInt64(Some(1)), + null_value: ScalarValue::UInt64(None), + }, + // String types + InListPrimitiveTestCase { + name: "utf8", + value_in: ScalarValue::Utf8(Some("a".to_string())), + value_not_in: ScalarValue::Utf8(Some("d".to_string())), + value_in_list: ScalarValue::Utf8(Some("b".to_string())), + null_value: ScalarValue::Utf8(None), + }, + InListPrimitiveTestCase { + name: "large_utf8", + value_in: ScalarValue::LargeUtf8(Some("a".to_string())), + value_not_in: ScalarValue::LargeUtf8(Some("d".to_string())), + value_in_list: ScalarValue::LargeUtf8(Some("b".to_string())), + null_value: ScalarValue::LargeUtf8(None), + }, + InListPrimitiveTestCase { + name: "utf8_view", + value_in: ScalarValue::Utf8View(Some("a".to_string())), + value_not_in: ScalarValue::Utf8View(Some("d".to_string())), + value_in_list: ScalarValue::Utf8View(Some("b".to_string())), + null_value: ScalarValue::Utf8View(None), + }, + // Binary types + InListPrimitiveTestCase { + name: "binary", + value_in: ScalarValue::Binary(Some(vec![1, 2, 3])), + value_not_in: ScalarValue::Binary(Some(vec![1, 2, 2])), + value_in_list: ScalarValue::Binary(Some(vec![4, 5, 6])), + null_value: ScalarValue::Binary(None), + }, + InListPrimitiveTestCase { + name: "large_binary", + value_in: 
ScalarValue::LargeBinary(Some(vec![1, 2, 3])), + value_not_in: ScalarValue::LargeBinary(Some(vec![1, 2, 2])), + value_in_list: ScalarValue::LargeBinary(Some(vec![4, 5, 6])), + null_value: ScalarValue::LargeBinary(None), + }, + InListPrimitiveTestCase { + name: "binary_view", + value_in: ScalarValue::BinaryView(Some(vec![1, 2, 3])), + value_not_in: ScalarValue::BinaryView(Some(vec![1, 2, 2])), + value_in_list: ScalarValue::BinaryView(Some(vec![4, 5, 6])), + null_value: ScalarValue::BinaryView(None), + }, + // Date types + InListPrimitiveTestCase { + name: "date32", + value_in: ScalarValue::Date32(Some(0)), + value_not_in: ScalarValue::Date32(Some(2)), + value_in_list: ScalarValue::Date32(Some(1)), + null_value: ScalarValue::Date32(None), + }, + InListPrimitiveTestCase { + name: "date64", + value_in: ScalarValue::Date64(Some(0)), + value_not_in: ScalarValue::Date64(Some(2)), + value_in_list: ScalarValue::Date64(Some(1)), + null_value: ScalarValue::Date64(None), + }, + // Decimal type + InListPrimitiveTestCase { + name: "decimal128", + value_in: ScalarValue::Decimal128(Some(0), 10, 2), + value_not_in: ScalarValue::Decimal128(Some(200), 10, 2), + value_in_list: ScalarValue::Decimal128(Some(100), 10, 2), + null_value: ScalarValue::Decimal128(None, 10, 2), + }, + // Timestamp types + InListPrimitiveTestCase { + name: "timestamp_nanosecond", + value_in: ScalarValue::TimestampNanosecond(Some(0), None), + value_not_in: ScalarValue::TimestampNanosecond(Some(2000), None), + value_in_list: ScalarValue::TimestampNanosecond(Some(1000), None), + null_value: ScalarValue::TimestampNanosecond(None, None), + }, + ]; - // expression: "a not in (0, 1, NULL)" - let list = vec![lit(0i64), lit(1i64), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); + for test_case in test_cases { + let test_name = test_case.name; + run_primitive_in_list_test(test_case).map_err(|e| { + 
datafusion_common::DataFusionError::Execution(format!( + "Test failed for type {}: {}", + test_name, e + )) + })?; + } Ok(()) } @@ -1078,1001 +1169,6 @@ mod tests { Ok(()) } - #[test] - fn in_list_int8() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int8, true)]); - let a = Int8Array::from(vec![Some(0), Some(2), None]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in (0, 1)" - let list = vec![lit(0i8), lit(1i8)]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1)" - let list = vec![lit(0i8), lit(1i8)]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in (0, 1, NULL)" - let list = vec![lit(0i8), lit(1i8), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1, NULL)" - let list = vec![lit(0i8), lit(1i8), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_int16() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int16, true)]); - let a = Int16Array::from(vec![Some(0), Some(2), None]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in (0, 1)" - let list = vec![lit(0i16), lit(1i16)]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1)" - let list = vec![lit(0i16), lit(1i16)]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in (0, 1, NULL)" 
- let list = vec![lit(0i16), lit(1i16), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1, NULL)" - let list = vec![lit(0i16), lit(1i16), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_int32() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - let a = Int32Array::from(vec![Some(0), Some(2), None]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in (0, 1)" - let list = vec![lit(0i32), lit(1i32)]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1)" - let list = vec![lit(0i32), lit(1i32)]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in (0, 1, NULL)" - let list = vec![lit(0i32), lit(1i32), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1, NULL)" - let list = vec![lit(0i32), lit(1i32), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_uint8() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::UInt8, true)]); - let a = UInt8Array::from(vec![Some(0), Some(2), None]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in (0, 1)" - let list = vec![lit(0u8), lit(1u8)]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - 
&schema - ); - - // expression: "a not in (0, 1)" - let list = vec![lit(0u8), lit(1u8)]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in (0, 1, NULL)" - let list = vec![lit(0u8), lit(1u8), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1, NULL)" - let list = vec![lit(0u8), lit(1u8), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_uint16() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::UInt16, true)]); - let a = UInt16Array::from(vec![Some(0), Some(2), None]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in (0, 1)" - let list = vec![lit(0u16), lit(1u16)]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1)" - let list = vec![lit(0u16), lit(1u16)]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in (0, 1, NULL)" - let list = vec![lit(0u16), lit(1u16), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1, NULL)" - let list = vec![lit(0u16), lit(1u16), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_uint32() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::UInt32, true)]); - let a = UInt32Array::from(vec![Some(0), Some(2), None]); - let col_a = col("a", &schema)?; - let batch = 
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in (0, 1)" - let list = vec![lit(0u32), lit(1u32)]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1)" - let list = vec![lit(0u32), lit(1u32)]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in (0, 1, NULL)" - let list = vec![lit(0u32), lit(1u32), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1, NULL)" - let list = vec![lit(0u32), lit(1u32), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_uint64() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::UInt64, true)]); - let a = UInt64Array::from(vec![Some(0), Some(2), None]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in (0, 1)" - let list = vec![lit(0u64), lit(1u64)]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1)" - let list = vec![lit(0u64), lit(1u64)]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in (0, 1, NULL)" - let list = vec![lit(0u64), lit(1u64), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1, NULL)" - let list = vec![lit(0u64), lit(1u64), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - 
); - - Ok(()) - } - - #[test] - fn in_list_large_utf8() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::LargeUtf8, true)]); - let a = LargeStringArray::from(vec![Some("a"), Some("d"), None]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in ("a", "b")" - let list = vec![lit("a"), lit("b")]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in ("a", "b")" - let list = vec![lit("a"), lit("b")]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in ("a", "b", null)" - let list = vec![lit("a"), lit("b"), lit(ScalarValue::LargeUtf8(None))]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in ("a", "b", null)" - let list = vec![lit("a"), lit("b"), lit(ScalarValue::LargeUtf8(None))]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_utf8_view() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8View, true)]); - let a = StringViewArray::from(vec![Some("a"), Some("d"), None]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in ("a", "b")" - let list = vec![lit("a"), lit("b")]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in ("a", "b")" - let list = vec![lit("a"), lit("b")]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in ("a", "b", null)" - let list = vec![lit("a"), lit("b"), 
lit(ScalarValue::Utf8View(None))]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in ("a", "b", null)" - let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8View(None))]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_large_binary() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::LargeBinary, true)]); - let a = LargeBinaryArray::from(vec![ - Some([1, 2, 3].as_slice()), - Some([1, 2, 2].as_slice()), - None, - ]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in ([1, 2, 3], [4, 5, 6])" - let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())]; - in_list!( - batch, - list.clone(), - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in ([1, 2, 3], [4, 5, 6])" - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in ([1, 2, 3], [4, 5, 6], null)" - let list = vec![ - lit([1, 2, 3].as_slice()), - lit([4, 5, 6].as_slice()), - lit(ScalarValue::LargeBinary(None)), - ]; - in_list!( - batch, - list.clone(), - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in ([1, 2, 3], [4, 5, 6], null)" - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_binary_view() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::BinaryView, true)]); - let a = BinaryViewArray::from(vec![ - Some([1, 2, 3].as_slice()), - Some([1, 2, 2].as_slice()), - None, - ]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; 
- - // expression: "a in ([1, 2, 3], [4, 5, 6])" - let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())]; - in_list!( - batch, - list.clone(), - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in ([1, 2, 3], [4, 5, 6])" - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in ([1, 2, 3], [4, 5, 6], null)" - let list = vec![ - lit([1, 2, 3].as_slice()), - lit([4, 5, 6].as_slice()), - lit(ScalarValue::BinaryView(None)), - ]; - in_list!( - batch, - list.clone(), - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in ([1, 2, 3], [4, 5, 6], null)" - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_date64() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Date64, true)]); - let a = Date64Array::from(vec![Some(0), Some(2), None]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in (0, 1)" - let list = vec![ - lit(ScalarValue::Date64(Some(0))), - lit(ScalarValue::Date64(Some(1))), - ]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1)" - let list = vec![ - lit(ScalarValue::Date64(Some(0))), - lit(ScalarValue::Date64(Some(1))), - ]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in (0, 1, NULL)" - let list = vec![ - lit(ScalarValue::Date64(Some(0))), - lit(ScalarValue::Date64(Some(1))), - lit(ScalarValue::Null), - ]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1, NULL)" - let 
list = vec![ - lit(ScalarValue::Date64(Some(0))), - lit(ScalarValue::Date64(Some(1))), - lit(ScalarValue::Null), - ]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_date32() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Date32, true)]); - let a = Date32Array::from(vec![Some(0), Some(2), None]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in (0, 1)" - let list = vec![ - lit(ScalarValue::Date32(Some(0))), - lit(ScalarValue::Date32(Some(1))), - ]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1)" - let list = vec![ - lit(ScalarValue::Date32(Some(0))), - lit(ScalarValue::Date32(Some(1))), - ]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in (0, 1, NULL)" - let list = vec![ - lit(ScalarValue::Date32(Some(0))), - lit(ScalarValue::Date32(Some(1))), - lit(ScalarValue::Null), - ]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0, 1, NULL)" - let list = vec![ - lit(ScalarValue::Date32(Some(0))), - lit(ScalarValue::Date32(Some(1))), - lit(ScalarValue::Null), - ]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn in_list_decimal() -> Result<()> { - // Now, we can check the NULL type - let schema = - Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]); - let array = vec![Some(100_0000_i128), None, Some(200_5000_i128)] - .into_iter() - .collect::(); - let array = array.with_precision_and_scale(13, 4).unwrap(); - let col_a = col("a", &schema)?; - let batch = - 
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)])?; - - // expression: "a in (100,200), the data type of list is INT32 - let list = vec![lit(100i32), lit(200i32)]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, Some(false)], - Arc::clone(&col_a), - &schema - ); - // expression: "a not in (100,200) - let list = vec![lit(100i32), lit(200i32)]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, Some(true)], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in (200,NULL), the data type of list is INT32 AND NULL - let list = vec![lit(ScalarValue::Int32(Some(100))), lit(ScalarValue::Null)]; - in_list!( - batch, - list.clone(), - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - // expression: "a not in (200,NULL), the data type of list is INT32 AND NULL - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - - // expression: "a in (200.5, 100), the data type of list is FLOAT32 and INT32 - let list = vec![lit(200.50f32), lit(100i32)]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, Some(true)], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (200.5, 100), the data type of list is FLOAT32 and INT32 - let list = vec![lit(200.50f32), lit(101i32)]; - in_list!( - batch, - list, - &true, - vec![Some(true), None, Some(false)], - Arc::clone(&col_a), - &schema - ); - - // test the optimization: set - // expression: "a in (99..300), the data type of list is INT32 - let list = (99i32..300).map(lit).collect::>(); - - in_list!( - batch, - list.clone(), - &false, - vec![Some(true), None, Some(false)], - Arc::clone(&col_a), - &schema - ); - - in_list!( - batch, - list, - &true, - vec![Some(false), None, Some(true)], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - - #[test] - fn test_cast_static_filter_to_set() -> Result<()> { - // random schema - let schema = - Schema::new(vec![Field::new("a", 
DataType::Decimal128(13, 4), true)]); - - // list of phy expr - let mut phy_exprs = vec![ - lit(1i64), - expressions::cast(lit(2i32), &schema, DataType::Int64)?, - try_cast(lit(3.13f32), &schema, DataType::Int64)?, - ]; - let static_filter = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); - - let array = Int64Array::from(vec![1, 2, 3, 4]); - let r = static_filter.contains(&array, false).unwrap(); - assert_eq!(r, BooleanArray::from(vec![true, true, true, false])); - - try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); - // cast(cast(lit())), but the cast to the same data type, one case will be ignored - phy_exprs.push(expressions::cast( - expressions::cast(lit(2i32), &schema, DataType::Int64)?, - &schema, - DataType::Int64, - )?); - try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); - - phy_exprs.clear(); - - // case(cast(lit())), the cast to the diff data type - phy_exprs.push(expressions::cast( - expressions::cast(lit(2i32), &schema, DataType::Int64)?, - &schema, - DataType::Int32, - )?); - try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); - - // column - phy_exprs.push(col("a", &schema)?); - assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_err()); - - Ok(()) - } - - #[test] - fn in_list_timestamp() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - )]); - let a = TimestampMicrosecondArray::from(vec![ - Some(1388588401000000000), - Some(1288588501000000000), - None, - ]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - let list = vec![ - lit(ScalarValue::TimestampMicrosecond( - Some(1388588401000000000), - None, - )), - lit(ScalarValue::TimestampMicrosecond( - Some(1388588401000000001), - None, - )), - lit(ScalarValue::TimestampMicrosecond( - Some(1388588401000000002), - None, - )), - ]; - - in_list!( - batch, - list.clone(), - &false, - vec![Some(true), 
Some(false), None], - Arc::clone(&col_a), - &schema - ); - - in_list!( - batch, - list.clone(), - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - Ok(()) - } - - #[test] - fn in_expr_with_multiple_element_in_list() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Float64, true), - Field::new("b", DataType::Float64, true), - Field::new("c", DataType::Float64, true), - ]); - let a = Float64Array::from(vec![ - Some(0.0), - Some(1.0), - Some(2.0), - Some(f64::NAN), - Some(-f64::NAN), - ]); - let b = Float64Array::from(vec![ - Some(8.0), - Some(1.0), - Some(5.0), - Some(f64::NAN), - Some(3.0), - ]); - let c = Float64Array::from(vec![ - Some(6.0), - Some(7.0), - None, - Some(5.0), - Some(-f64::NAN), - ]); - let col_a = col("a", &schema)?; - let col_b = col("b", &schema)?; - let col_c = col("c", &schema)?; - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(a), Arc::new(b), Arc::new(c)], - )?; - - let list = vec![Arc::clone(&col_b), Arc::clone(&col_c)]; - in_list!( - batch, - list.clone(), - &false, - vec![Some(false), Some(true), None, Some(true), Some(true)], - Arc::clone(&col_a), - &schema - ); - - in_list!( - batch, - list, - &true, - vec![Some(true), Some(false), None, Some(false), Some(false)], - Arc::clone(&col_a), - &schema - ); - - Ok(()) - } - macro_rules! 
test_nullable { ($COL:expr, $LIST:expr, $SCHEMA:expr, $EXPECTED:expr) => {{ let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?; From 71d747ee9a2fe3ce5ad0b976da82311cb1195eff Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 8 Dec 2025 12:18:25 -0600 Subject: [PATCH 08/15] further consolidate --- .../physical-expr/src/expressions/in_list.rs | 185 +++++++----------- 1 file changed, 73 insertions(+), 112 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index f78abd5bdd03..f8e97ac505b2 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -687,14 +687,6 @@ mod tests { } } - fn try_cast_static_filter_to_set( - list: &[Arc], - schema: &Schema, - ) -> Result { - let array = try_evaluate_constant_list(list, schema)?; - ArrayStaticFilter::try_new(array) - } - // Attempts to coerce the types of `list_type` to be comparable with the // `expr_type` fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option { @@ -772,13 +764,66 @@ mod tests { null_value: ScalarValue, } + /// Helper to create test cases for integer types (signed or unsigned). + /// + /// All integer types use the same test values: 0 (in list), 2 (not in list), 1 (filler). + /// Uses TryFrom which all integer types implement, with unwrap() since 0, 1, 2 always fit. + fn int_test_case(name: &'static str, constructor: F) -> InListPrimitiveTestCase + where + F: Fn(Option) -> ScalarValue, + T: TryFrom, + >::Error: Debug, + { + InListPrimitiveTestCase { + name, + value_in: constructor(Some(T::try_from(0).unwrap())), + value_not_in: constructor(Some(T::try_from(2).unwrap())), + value_in_list: constructor(Some(T::try_from(1).unwrap())), + null_value: constructor(None), + } + } + + /// Helper to create test cases for string types (Utf8, LargeUtf8, Utf8View). 
+ /// + /// All string types use the same test values: "a" (in list), "d" (not in list), "b" (filler). + fn string_test_case( + name: &'static str, + constructor: impl Fn(Option) -> ScalarValue, + ) -> InListPrimitiveTestCase { + InListPrimitiveTestCase { + name, + value_in: constructor(Some("a".to_string())), + value_not_in: constructor(Some("d".to_string())), + value_in_list: constructor(Some("b".to_string())), + null_value: constructor(None), + } + } + + /// Helper to create test cases for binary types (Binary, LargeBinary, BinaryView). + /// + /// All binary types use the same test values: [1,2,3] (in list), [1,2,2] (not in list), [4,5,6] (filler). + fn binary_test_case( + name: &'static str, + constructor: impl Fn(Option>) -> ScalarValue, + ) -> InListPrimitiveTestCase { + InListPrimitiveTestCase { + name, + value_in: constructor(Some(vec![1, 2, 3])), + value_not_in: constructor(Some(vec![1, 2, 2])), + value_in_list: constructor(Some(vec![4, 5, 6])), + null_value: constructor(None), + } + } + /// Runs the standard 4 IN LIST test scenarios for a primitive type. /// /// Creates a test array with [Some(value_in), Some(value_not_in), None] and tests: /// 1. `a IN (value_in, value_in_list)` → `[true, false, null]` /// 2. `a NOT IN (value_in, value_in_list)` → `[false, true, null]` /// 3. `a IN (value_in, value_in_list, NULL)` → `[true, null, null]` - /// 4. `a NOT IN (value_in, value_in_list, NULL)` → `[false, null, null]` + /// 4. `a NOT IN (value_in, value_in_list, NULL)` → `[false, null, null]`\ + /// + /// Where `a` has values `[Some(value_in), Some(value_not_in), None]`. 
fn run_primitive_in_list_test(test_case: InListPrimitiveTestCase) -> Result<()> { // Get the data type from the scalar value let data_type = test_case.value_in.data_type(); @@ -863,109 +908,25 @@ mod tests { #[test] fn in_list_primitive_types() -> Result<()> { let test_cases = vec![ - // Signed integers - InListPrimitiveTestCase { - name: "int8", - value_in: ScalarValue::Int8(Some(0)), - value_not_in: ScalarValue::Int8(Some(2)), - value_in_list: ScalarValue::Int8(Some(1)), - null_value: ScalarValue::Int8(None), - }, - InListPrimitiveTestCase { - name: "int16", - value_in: ScalarValue::Int16(Some(0)), - value_not_in: ScalarValue::Int16(Some(2)), - value_in_list: ScalarValue::Int16(Some(1)), - null_value: ScalarValue::Int16(None), - }, - InListPrimitiveTestCase { - name: "int32", - value_in: ScalarValue::Int32(Some(0)), - value_not_in: ScalarValue::Int32(Some(2)), - value_in_list: ScalarValue::Int32(Some(1)), - null_value: ScalarValue::Int32(None), - }, - InListPrimitiveTestCase { - name: "int64", - value_in: ScalarValue::Int64(Some(0)), - value_not_in: ScalarValue::Int64(Some(2)), - value_in_list: ScalarValue::Int64(Some(1)), - null_value: ScalarValue::Int64(None), - }, - // Unsigned integers - InListPrimitiveTestCase { - name: "uint8", - value_in: ScalarValue::UInt8(Some(0)), - value_not_in: ScalarValue::UInt8(Some(2)), - value_in_list: ScalarValue::UInt8(Some(1)), - null_value: ScalarValue::UInt8(None), - }, - InListPrimitiveTestCase { - name: "uint16", - value_in: ScalarValue::UInt16(Some(0)), - value_not_in: ScalarValue::UInt16(Some(2)), - value_in_list: ScalarValue::UInt16(Some(1)), - null_value: ScalarValue::UInt16(None), - }, - InListPrimitiveTestCase { - name: "uint32", - value_in: ScalarValue::UInt32(Some(0)), - value_not_in: ScalarValue::UInt32(Some(2)), - value_in_list: ScalarValue::UInt32(Some(1)), - null_value: ScalarValue::UInt32(None), - }, - InListPrimitiveTestCase { - name: "uint64", - value_in: ScalarValue::UInt64(Some(0)), - value_not_in: 
ScalarValue::UInt64(Some(2)), - value_in_list: ScalarValue::UInt64(Some(1)), - null_value: ScalarValue::UInt64(None), - }, - // String types - InListPrimitiveTestCase { - name: "utf8", - value_in: ScalarValue::Utf8(Some("a".to_string())), - value_not_in: ScalarValue::Utf8(Some("d".to_string())), - value_in_list: ScalarValue::Utf8(Some("b".to_string())), - null_value: ScalarValue::Utf8(None), - }, - InListPrimitiveTestCase { - name: "large_utf8", - value_in: ScalarValue::LargeUtf8(Some("a".to_string())), - value_not_in: ScalarValue::LargeUtf8(Some("d".to_string())), - value_in_list: ScalarValue::LargeUtf8(Some("b".to_string())), - null_value: ScalarValue::LargeUtf8(None), - }, - InListPrimitiveTestCase { - name: "utf8_view", - value_in: ScalarValue::Utf8View(Some("a".to_string())), - value_not_in: ScalarValue::Utf8View(Some("d".to_string())), - value_in_list: ScalarValue::Utf8View(Some("b".to_string())), - null_value: ScalarValue::Utf8View(None), - }, - // Binary types - InListPrimitiveTestCase { - name: "binary", - value_in: ScalarValue::Binary(Some(vec![1, 2, 3])), - value_not_in: ScalarValue::Binary(Some(vec![1, 2, 2])), - value_in_list: ScalarValue::Binary(Some(vec![4, 5, 6])), - null_value: ScalarValue::Binary(None), - }, - InListPrimitiveTestCase { - name: "large_binary", - value_in: ScalarValue::LargeBinary(Some(vec![1, 2, 3])), - value_not_in: ScalarValue::LargeBinary(Some(vec![1, 2, 2])), - value_in_list: ScalarValue::LargeBinary(Some(vec![4, 5, 6])), - null_value: ScalarValue::LargeBinary(None), - }, - InListPrimitiveTestCase { - name: "binary_view", - value_in: ScalarValue::BinaryView(Some(vec![1, 2, 3])), - value_not_in: ScalarValue::BinaryView(Some(vec![1, 2, 2])), - value_in_list: ScalarValue::BinaryView(Some(vec![4, 5, 6])), - null_value: ScalarValue::BinaryView(None), - }, - // Date types + // Signed integers (4 lines instead of 16) + int_test_case("int8", ScalarValue::Int8), + int_test_case("int16", ScalarValue::Int16), + int_test_case("int32", 
ScalarValue::Int32), + int_test_case("int64", ScalarValue::Int64), + // Unsigned integers (4 lines instead of 16) + int_test_case("uint8", ScalarValue::UInt8), + int_test_case("uint16", ScalarValue::UInt16), + int_test_case("uint32", ScalarValue::UInt32), + int_test_case("uint64", ScalarValue::UInt64), + // String types (3 lines instead of 12) + string_test_case("utf8", ScalarValue::Utf8), + string_test_case("large_utf8", ScalarValue::LargeUtf8), + string_test_case("utf8_view", ScalarValue::Utf8View), + // Binary types (3 lines instead of 12) + binary_test_case("binary", ScalarValue::Binary), + binary_test_case("large_binary", ScalarValue::LargeBinary), + binary_test_case("binary_view", ScalarValue::BinaryView), + // Date types (keep as-is - use different values than integers) InListPrimitiveTestCase { name: "date32", value_in: ScalarValue::Date32(Some(0)), From 1878aaa7dd8366c448abaab135ca6c2a9bef67c7 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 8 Dec 2025 12:56:35 -0600 Subject: [PATCH 09/15] refactor helpers --- .../physical-expr/src/expressions/in_list.rs | 391 ++++++++++-------- 1 file changed, 213 insertions(+), 178 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index f8e97ac505b2..d922cafde2a5 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -764,212 +764,247 @@ mod tests { null_value: ScalarValue, } - /// Helper to create test cases for integer types (signed or unsigned). + /// Generic test data struct for primitive types. /// - /// All integer types use the same test values: 0 (in list), 2 (not in list), 1 (filler). - /// Uses TryFrom which all integer types implement, with unwrap() since 0, 1, 2 always fit. 
- fn int_test_case(name: &'static str, constructor: F) -> InListPrimitiveTestCase + /// Holds the three test values needed for IN LIST tests, allowing the data + /// to be declared explicitly and reused across multiple types. + #[derive(Clone)] + struct PrimitiveTestCaseData { + value_in: T, + value_not_in: T, + value_in_list: T, + } + + /// Helper to create test cases for any primitive type using generic data. + /// + /// Uses TryInto for flexible type conversion, allowing test data to be + /// declared in any convertible type (e.g., i32 for all integer types). + fn primitive_test_case( + name: &'static str, + constructor: F, + data: PrimitiveTestCaseData, + ) -> InListPrimitiveTestCase where + D: TryInto, + >::Error: Debug, F: Fn(Option) -> ScalarValue, - T: TryFrom, - >::Error: Debug, + T: Clone, { InListPrimitiveTestCase { name, - value_in: constructor(Some(T::try_from(0).unwrap())), - value_not_in: constructor(Some(T::try_from(2).unwrap())), - value_in_list: constructor(Some(T::try_from(1).unwrap())), + value_in: constructor(Some(data.value_in.try_into().unwrap())), + value_not_in: constructor(Some(data.value_not_in.try_into().unwrap())), + value_in_list: constructor(Some(data.value_in_list.try_into().unwrap())), null_value: constructor(None), } } - /// Helper to create test cases for string types (Utf8, LargeUtf8, Utf8View). + /// Runs test cases for multiple types, providing detailed SQL error messages on failure. /// - /// All string types use the same test values: "a" (in list), "d" (not in list), "b" (filler). 
- fn string_test_case( - name: &'static str, - constructor: impl Fn(Option) -> ScalarValue, - ) -> InListPrimitiveTestCase { - InListPrimitiveTestCase { - name, - value_in: constructor(Some("a".to_string())), - value_not_in: constructor(Some("d".to_string())), - value_in_list: constructor(Some("b".to_string())), - null_value: constructor(None), + /// For each test case, runs 4 standard IN LIST scenarios and provides context + /// about the test data and expected behavior when assertions fail. + fn run_test_cases(test_cases: Vec) -> Result<()> { + for test_case in test_cases { + let test_name = test_case.name; + + // Get the data type from the scalar value + let data_type = test_case.value_in.data_type(); + let schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); + + // Create array from scalar values: [value_in, value_not_in, None] + let array = ScalarValue::iter_to_array(vec![ + test_case.value_in.clone(), + test_case.value_not_in.clone(), + test_case.null_value.clone(), + ])?; + + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![array.clone()])?; + + // Helper to format SQL-like representation for error messages + let _format_sql = |negated: bool, with_null: bool| -> String { + let not_str = if negated { "NOT " } else { "" }; + let null_str = if with_null { + format!(", {}", test_case.null_value) + } else { + String::new() + }; + format!( + "Test '{}': a {}IN ({}, {}{})\n where a = [{}, {}, NULL]", + test_name, + not_str, + test_case.value_in, + test_case.value_in_list, + null_str, + test_case.value_in, + test_case.value_not_in + ) + }; + + // Test 1: a IN (value_in, value_in_list) → [true, false, null] + let list = vec![ + lit(test_case.value_in.clone()), + lit(test_case.value_in_list.clone()), + ]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // Test 2: a NOT IN (value_in, value_in_list) → [false, true, null] + let list = 
vec![ + lit(test_case.value_in.clone()), + lit(test_case.value_in_list.clone()), + ]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // Test 3: a IN (value_in, value_in_list, NULL) → [true, null, null] + let list = vec![ + lit(test_case.value_in.clone()), + lit(test_case.value_in_list.clone()), + lit(test_case.null_value.clone()), + ]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // Test 4: a NOT IN (value_in, value_in_list, NULL) → [false, null, null] + let list = vec![ + lit(test_case.value_in), + lit(test_case.value_in_list), + lit(test_case.null_value), + ]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); } + + Ok(()) } - /// Helper to create test cases for binary types (Binary, LargeBinary, BinaryView). + /// Test IN LIST for all integer types (Int8/16/32/64, UInt8/16/32/64). /// - /// All binary types use the same test values: [1,2,3] (in list), [1,2,2] (not in list), [4,5,6] (filler). 
- fn binary_test_case( - name: &'static str, - constructor: impl Fn(Option>) -> ScalarValue, - ) -> InListPrimitiveTestCase { - InListPrimitiveTestCase { - name, - value_in: constructor(Some(vec![1, 2, 3])), - value_not_in: constructor(Some(vec![1, 2, 2])), - value_in_list: constructor(Some(vec![4, 5, 6])), - null_value: constructor(None), - } + /// Test data: values 0 (in list), 2 (not in list), 1 (filler) + #[test] + fn in_list_int_types() -> Result<()> { + let int_data = PrimitiveTestCaseData { + value_in: 0, + value_not_in: 2, + value_in_list: 1, + }; + + run_test_cases(vec![ + primitive_test_case("int8", ScalarValue::Int8, int_data.clone()), + primitive_test_case("int16", ScalarValue::Int16, int_data.clone()), + primitive_test_case("int32", ScalarValue::Int32, int_data.clone()), + primitive_test_case("int64", ScalarValue::Int64, int_data.clone()), + primitive_test_case("uint8", ScalarValue::UInt8, int_data.clone()), + primitive_test_case("uint16", ScalarValue::UInt16, int_data.clone()), + primitive_test_case("uint32", ScalarValue::UInt32, int_data.clone()), + primitive_test_case("uint64", ScalarValue::UInt64, int_data), + ]) } - /// Runs the standard 4 IN LIST test scenarios for a primitive type. - /// - /// Creates a test array with [Some(value_in), Some(value_not_in), None] and tests: - /// 1. `a IN (value_in, value_in_list)` → `[true, false, null]` - /// 2. `a NOT IN (value_in, value_in_list)` → `[false, true, null]` - /// 3. `a IN (value_in, value_in_list, NULL)` → `[true, null, null]` - /// 4. `a NOT IN (value_in, value_in_list, NULL)` → `[false, null, null]`\ + /// Test IN LIST for all string types (Utf8, LargeUtf8, Utf8View). /// - /// Where `a` has values `[Some(value_in), Some(value_not_in), None]`. 
- fn run_primitive_in_list_test(test_case: InListPrimitiveTestCase) -> Result<()> { - // Get the data type from the scalar value - let data_type = test_case.value_in.data_type(); - let schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); - - // Create array from scalar values: [value_in, value_not_in, None] - let array = ScalarValue::iter_to_array(vec![ - test_case.value_in.clone(), - test_case.value_not_in.clone(), - test_case.null_value.clone(), - ])?; - - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![array])?; + /// Test data: "a" (in list), "d" (not in list), "b" (filler) + #[test] + fn in_list_string_types() -> Result<()> { + let string_data = PrimitiveTestCaseData { + value_in: "a", + value_not_in: "d", + value_in_list: "b", + }; - // Test 1: a IN (value_in, value_in_list) - let list = vec![ - lit(test_case.value_in.clone()), - lit(test_case.value_in_list.clone()), - ]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); + run_test_cases(vec![ + primitive_test_case("utf8", ScalarValue::Utf8, string_data.clone()), + primitive_test_case("large_utf8", ScalarValue::LargeUtf8, string_data.clone()), + primitive_test_case("utf8_view", ScalarValue::Utf8View, string_data), + ]) + } - // Test 2: a NOT IN (value_in, value_in_list) - let list = vec![ - lit(test_case.value_in.clone()), - lit(test_case.value_in_list.clone()), - ]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); + /// Test IN LIST for all binary types (Binary, LargeBinary, BinaryView). 
+ /// + /// Test data: [1,2,3] (in list), [1,2,2] (not in list), [4,5,6] (filler) + #[test] + fn in_list_binary_types() -> Result<()> { + let binary_data = PrimitiveTestCaseData { + value_in: vec![1_u8, 2, 3], + value_not_in: vec![1_u8, 2, 2], + value_in_list: vec![4_u8, 5, 6], + }; - // Test 3: a IN (value_in, value_in_list, NULL) - let list = vec![ - lit(test_case.value_in.clone()), - lit(test_case.value_in_list.clone()), - lit(test_case.null_value.clone()), - ]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); + run_test_cases(vec![ + primitive_test_case("binary", ScalarValue::Binary, binary_data.clone()), + primitive_test_case("large_binary", ScalarValue::LargeBinary, binary_data.clone()), + primitive_test_case("binary_view", ScalarValue::BinaryView, binary_data), + ]) + } - // Test 4: a NOT IN (value_in, value_in_list, NULL) - let list = vec![ - lit(test_case.value_in), - lit(test_case.value_in_list), - lit(test_case.null_value), - ]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); + /// Test IN LIST for date types (Date32, Date64). + /// + /// Test data: 0 (in list), 2 (not in list), 1 (filler) + #[test] + fn in_list_date_types() -> Result<()> { + let date_data = PrimitiveTestCaseData { + value_in: 0, + value_not_in: 2, + value_in_list: 1, + }; - Ok(()) + run_test_cases(vec![ + primitive_test_case("date32", ScalarValue::Date32, date_data.clone()), + primitive_test_case("date64", ScalarValue::Date64, date_data), + ]) } - /// Consolidated test for all primitive types following the standard IN LIST pattern. + /// Test IN LIST for Decimal128 type. /// - /// This test replaces individual test functions for: Int8/16/32/64, UInt8/16/32/64, - /// Utf8, LargeUtf8, Utf8View, Binary, LargeBinary, BinaryView, Date32, Date64, - /// Decimal, and Timestamp types. 
+ /// Test data: 0 (in list), 200 (not in list), 100 (filler) with precision=10, scale=2 #[test] - fn in_list_primitive_types() -> Result<()> { - let test_cases = vec![ - // Signed integers (4 lines instead of 16) - int_test_case("int8", ScalarValue::Int8), - int_test_case("int16", ScalarValue::Int16), - int_test_case("int32", ScalarValue::Int32), - int_test_case("int64", ScalarValue::Int64), - // Unsigned integers (4 lines instead of 16) - int_test_case("uint8", ScalarValue::UInt8), - int_test_case("uint16", ScalarValue::UInt16), - int_test_case("uint32", ScalarValue::UInt32), - int_test_case("uint64", ScalarValue::UInt64), - // String types (3 lines instead of 12) - string_test_case("utf8", ScalarValue::Utf8), - string_test_case("large_utf8", ScalarValue::LargeUtf8), - string_test_case("utf8_view", ScalarValue::Utf8View), - // Binary types (3 lines instead of 12) - binary_test_case("binary", ScalarValue::Binary), - binary_test_case("large_binary", ScalarValue::LargeBinary), - binary_test_case("binary_view", ScalarValue::BinaryView), - // Date types (keep as-is - use different values than integers) - InListPrimitiveTestCase { - name: "date32", - value_in: ScalarValue::Date32(Some(0)), - value_not_in: ScalarValue::Date32(Some(2)), - value_in_list: ScalarValue::Date32(Some(1)), - null_value: ScalarValue::Date32(None), - }, - InListPrimitiveTestCase { - name: "date64", - value_in: ScalarValue::Date64(Some(0)), - value_not_in: ScalarValue::Date64(Some(2)), - value_in_list: ScalarValue::Date64(Some(1)), - null_value: ScalarValue::Date64(None), - }, - // Decimal type - InListPrimitiveTestCase { - name: "decimal128", - value_in: ScalarValue::Decimal128(Some(0), 10, 2), - value_not_in: ScalarValue::Decimal128(Some(200), 10, 2), - value_in_list: ScalarValue::Decimal128(Some(100), 10, 2), - null_value: ScalarValue::Decimal128(None, 10, 2), - }, - // Timestamp types - InListPrimitiveTestCase { - name: "timestamp_nanosecond", - value_in: 
ScalarValue::TimestampNanosecond(Some(0), None), - value_not_in: ScalarValue::TimestampNanosecond(Some(2000), None), - value_in_list: ScalarValue::TimestampNanosecond(Some(1000), None), - null_value: ScalarValue::TimestampNanosecond(None, None), - }, - ]; - - for test_case in test_cases { - let test_name = test_case.name; - run_primitive_in_list_test(test_case).map_err(|e| { - datafusion_common::DataFusionError::Execution(format!( - "Test failed for type {}: {}", - test_name, e - )) - })?; - } + fn in_list_decimal() -> Result<()> { + run_test_cases(vec![InListPrimitiveTestCase { + name: "decimal128", + value_in: ScalarValue::Decimal128(Some(0), 10, 2), + value_not_in: ScalarValue::Decimal128(Some(200), 10, 2), + value_in_list: ScalarValue::Decimal128(Some(100), 10, 2), + null_value: ScalarValue::Decimal128(None, 10, 2), + }]) + } - Ok(()) + /// Test IN LIST for timestamp types. + /// + /// Test data: 0 (in list), 2000 (not in list), 1000 (filler) + #[test] + fn in_list_timestamp_types() -> Result<()> { + run_test_cases(vec![InListPrimitiveTestCase { + name: "timestamp_nanosecond", + value_in: ScalarValue::TimestampNanosecond(Some(0), None), + value_not_in: ScalarValue::TimestampNanosecond(Some(2000), None), + value_in_list: ScalarValue::TimestampNanosecond(Some(1000), None), + null_value: ScalarValue::TimestampNanosecond(None, None), + }]) } #[test] From 0ef15f225398956f3ce2133e1b7d2458bd2951e0 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 8 Dec 2025 13:10:04 -0600 Subject: [PATCH 10/15] add dictionary tests --- .../physical-expr/src/expressions/in_list.rs | 268 +++++++++++++++++- 1 file changed, 265 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index d922cafde2a5..919342398b72 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ 
-819,7 +819,8 @@ mod tests { ])?; let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![array.clone()])?; + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![array.clone()])?; // Helper to format SQL-like representation for error messages let _format_sql = |negated: bool, with_null: bool| -> String { @@ -939,7 +940,11 @@ mod tests { run_test_cases(vec![ primitive_test_case("utf8", ScalarValue::Utf8, string_data.clone()), - primitive_test_case("large_utf8", ScalarValue::LargeUtf8, string_data.clone()), + primitive_test_case( + "large_utf8", + ScalarValue::LargeUtf8, + string_data.clone(), + ), primitive_test_case("utf8_view", ScalarValue::Utf8View, string_data), ]) } @@ -957,7 +962,11 @@ mod tests { run_test_cases(vec![ primitive_test_case("binary", ScalarValue::Binary, binary_data.clone()), - primitive_test_case("large_binary", ScalarValue::LargeBinary, binary_data.clone()), + primitive_test_case( + "large_binary", + ScalarValue::LargeBinary, + binary_data.clone(), + ), primitive_test_case("binary_view", ScalarValue::BinaryView, binary_data), ]) } @@ -2596,4 +2605,257 @@ mod tests { assert_eq!(result, &BooleanArray::from(vec![true, true, false])); Ok(()) } + + #[test] + fn test_in_list_dictionary_types() -> Result<()> { + // Helper functions for creating dictionary literals + fn dict_lit_int64(key_type: DataType, value: i64) -> Arc { + lit(ScalarValue::Dictionary( + Box::new(key_type), + Box::new(ScalarValue::Int64(Some(value))), + )) + } + + fn dict_lit_float64(key_type: DataType, value: f64) -> Arc { + lit(ScalarValue::Dictionary( + Box::new(key_type), + Box::new(ScalarValue::Float64(Some(value))), + )) + } + + // Test case structures + struct DictNeedleTest { + list_values: Vec>, + expected: Vec>, + } + + struct DictionaryInListTestCase { + _name: &'static str, + dict_type: DataType, + dict_keys: Vec>, + dict_values: ArrayRef, + list_values_no_null: Vec>, + list_values_with_null: Vec>, + expected_1: 
Vec>, + expected_2: Vec>, + expected_3: Vec>, + expected_4: Vec>, + dict_needle_test: Option, + } + + // Test harness function + fn run_dictionary_in_list_test( + test_case: DictionaryInListTestCase, + ) -> Result<()> { + // Create schema with dictionary type + let schema = + Schema::new(vec![Field::new("a", test_case.dict_type.clone(), true)]); + let col_a = col("a", &schema)?; + + // Create dictionary array from keys and values + let keys = Int8Array::from(test_case.dict_keys.clone()); + let dict_array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, test_case.dict_values)?); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![dict_array])?; + + let exp1 = test_case.expected_1.clone(); + let exp2 = test_case.expected_2.clone(); + let exp3 = test_case.expected_3.clone(); + let exp4 = test_case.expected_4; + + // Test 1: a IN (values_no_null) + in_list!( + batch, + test_case.list_values_no_null.clone(), + &false, + exp1, + Arc::clone(&col_a), + &schema + ); + + // Test 2: a NOT IN (values_no_null) + in_list!( + batch, + test_case.list_values_no_null.clone(), + &true, + exp2, + Arc::clone(&col_a), + &schema + ); + + // Test 3: a IN (values_with_null) + in_list!( + batch, + test_case.list_values_with_null.clone(), + &false, + exp3, + Arc::clone(&col_a), + &schema + ); + + // Test 4: a NOT IN (values_with_null) + in_list!( + batch, + test_case.list_values_with_null, + &true, + exp4, + Arc::clone(&col_a), + &schema + ); + + // Optional: Dictionary needle test (if provided) + if let Some(needle_test) = test_case.dict_needle_test { + in_list_raw!( + batch, + needle_test.list_values, + &false, + needle_test.expected, + Arc::clone(&col_a), + &schema + ); + } + + Ok(()) + } + + // Test case 1: UTF8 + // Dictionary: keys [0, 1, null] → values ["a", "d", -] + // Rows: ["a", "d", null] + let utf8_case = DictionaryInListTestCase { + _name: "dictionary_utf8", + dict_type: DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + ), + 
dict_keys: vec![Some(0), Some(1), None], + dict_values: Arc::new(StringArray::from(vec![Some("a"), Some("d")])), + list_values_no_null: vec![lit("a"), lit("b")], + list_values_with_null: vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))], + expected_1: vec![Some(true), Some(false), None], + expected_2: vec![Some(false), Some(true), None], + expected_3: vec![Some(true), None, None], + expected_4: vec![Some(false), None, None], + dict_needle_test: None, + }; + + // Test case 2: Int64 with dictionary needles + // Dictionary: keys [0, 1, null] → values [10, 20, -] + // Rows: [10, 20, null] + let int64_case = DictionaryInListTestCase { + _name: "dictionary_int64", + dict_type: DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Int64), + ), + dict_keys: vec![Some(0), Some(1), None], + dict_values: Arc::new(Int64Array::from(vec![Some(10), Some(20)])), + list_values_no_null: vec![lit(10i64), lit(15i64)], + list_values_with_null: vec![ + lit(10i64), + lit(15i64), + lit(ScalarValue::Int64(None)), + ], + expected_1: vec![Some(true), Some(false), None], + expected_2: vec![Some(false), Some(true), None], + expected_3: vec![Some(true), None, None], + expected_4: vec![Some(false), None, None], + dict_needle_test: Some(DictNeedleTest { + list_values: vec![ + dict_lit_int64(DataType::Int16, 10), + dict_lit_int64(DataType::Int16, 15), + ], + expected: vec![Some(true), Some(false), None], + }), + }; + + // Test case 3: Float64 with NaN and dictionary needles + // Dictionary: keys [0, 1, null, 2] → values [1.5, 3.7, NaN, -] + // Rows: [1.5, 3.7, null, NaN] + // Note: NaN is a value (not null), so it goes in the values array + let float64_case = DictionaryInListTestCase { + _name: "dictionary_float64", + dict_type: DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Float64), + ), + dict_keys: vec![Some(0), Some(1), None, Some(2)], + dict_values: Arc::new(Float64Array::from(vec![ + Some(1.5), // index 0 + Some(3.7), // index 1 + 
Some(f64::NAN), // index 2 + ])), + list_values_no_null: vec![lit(1.5f64), lit(2.0f64)], + list_values_with_null: vec![ + lit(1.5f64), + lit(2.0f64), + lit(ScalarValue::Float64(None)), + ], + // Test 1: a IN (1.5, 2.0) → [true, false, null, false] + // NaN is false because NaN not in list and no NULL in list + expected_1: vec![Some(true), Some(false), None, Some(false)], + // Test 2: a NOT IN (1.5, 2.0) → [false, true, null, true] + // NaN is true because NaN not in list + expected_2: vec![Some(false), Some(true), None, Some(true)], + // Test 3: a IN (1.5, 2.0, NULL) → [true, null, null, null] + // 3.7 and NaN become null due to NULL in list (three-valued logic) + expected_3: vec![Some(true), None, None, None], + // Test 4: a NOT IN (1.5, 2.0, NULL) → [false, null, null, null] + // 3.7 and NaN become null due to NULL in list + expected_4: vec![Some(false), None, None, None], + dict_needle_test: Some(DictNeedleTest { + list_values: vec![ + dict_lit_float64(DataType::UInt16, 1.5), + dict_lit_float64(DataType::UInt16, 2.0), + ], + expected: vec![Some(true), Some(false), None, Some(false)], + }), + }; + + // Execute all test cases + run_dictionary_in_list_test(utf8_case).map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "Dictionary test failed for UTF8: {}", + e + )) + })?; + + run_dictionary_in_list_test(int64_case).map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "Dictionary test failed for Int64: {}", + e + )) + })?; + + run_dictionary_in_list_test(float64_case).map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "Dictionary test failed for Float64: {}", + e + )) + })?; + + // Additional test for Float64 NaN in IN list + let dict_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Float64)); + let schema = Schema::new(vec![Field::new("a", dict_type.clone(), true)]); + let col_a = col("a", &schema)?; + + let keys = Int8Array::from(vec![Some(0), Some(1), None, Some(2)]); + let 
values = Float64Array::from(vec![Some(1.5), Some(3.7), Some(f64::NAN)]); + let dict_array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, Arc::new(values))?); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![dict_array])?; + + // Test: a IN (1.5, 2.0, NaN) + let list_with_nan = vec![lit(1.5f64), lit(2.0f64), lit(f64::NAN)]; + in_list!( + batch, + list_with_nan, + &false, + vec![Some(true), Some(false), None, Some(true)], + col_a, + &schema + ); + + Ok(()) + } } From 488b16f71788c15cca4934550bef91836e3f0b84 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 8 Dec 2025 13:14:10 -0600 Subject: [PATCH 11/15] more dict test cases --- .../physical-expr/src/expressions/in_list.rs | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 919342398b72..f79f8c3f0e8e 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -2833,6 +2833,41 @@ mod tests { )) })?; + // Additional test: Dictionary deduplication with repeated keys + // This tests that multiple rows with the same key (pointing to the same value) + // are evaluated correctly + let dedup_case = DictionaryInListTestCase { + _name: "dictionary_deduplication", + dict_type: DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + ), + // Keys: [0, 1, 0, 1, null] - keys 0 and 1 are repeated + // This creates data: ["a", "d", "a", "d", null] + dict_keys: vec![Some(0), Some(1), Some(0), Some(1), None], + dict_values: Arc::new(StringArray::from(vec![Some("a"), Some("d")])), + list_values_no_null: vec![lit("a"), lit("b")], + list_values_with_null: vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))], + // Test 1: a IN ("a", "b") → [true, false, true, false, null] + // Rows 0 and 2 both have key 0 → "a", so both are true + 
expected_1: vec![Some(true), Some(false), Some(true), Some(false), None], + // Test 2: a NOT IN ("a", "b") → [false, true, false, true, null] + expected_2: vec![Some(false), Some(true), Some(false), Some(true), None], + // Test 3: a IN ("a", "b", NULL) → [true, null, true, null, null] + // "d" becomes null due to NULL in list + expected_3: vec![Some(true), None, Some(true), None, None], + // Test 4: a NOT IN ("a", "b", NULL) → [false, null, false, null, null] + expected_4: vec![Some(false), None, Some(false), None, None], + dict_needle_test: None, + }; + + run_dictionary_in_list_test(dedup_case).map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "Dictionary deduplication test failed: {}", + e + )) + })?; + // Additional test for Float64 NaN in IN list let dict_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Float64)); From f13c308e4af17a01be114d7280e3c28e07206e8b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 8 Dec 2025 13:17:56 -0600 Subject: [PATCH 12/15] include test name in errors --- .../physical-expr/src/expressions/in_list.rs | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index f79f8c3f0e8e..590786b264b3 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -2630,7 +2630,7 @@ mod tests { } struct DictionaryInListTestCase { - _name: &'static str, + name: &'static str, dict_type: DataType, dict_keys: Vec>, dict_values: ArrayRef, @@ -2722,7 +2722,7 @@ mod tests { // Dictionary: keys [0, 1, null] → values ["a", "d", -] // Rows: ["a", "d", null] let utf8_case = DictionaryInListTestCase { - _name: "dictionary_utf8", + name: "dictionary_utf8", dict_type: DataType::Dictionary( Box::new(DataType::Int8), Box::new(DataType::Utf8), @@ -2742,7 
+2742,7 @@ mod tests { // Dictionary: keys [0, 1, null] → values [10, 20, -] // Rows: [10, 20, null] let int64_case = DictionaryInListTestCase { - _name: "dictionary_int64", + name: "dictionary_int64", dict_type: DataType::Dictionary( Box::new(DataType::Int8), Box::new(DataType::Int64), @@ -2773,7 +2773,7 @@ mod tests { // Rows: [1.5, 3.7, null, NaN] // Note: NaN is a value (not null), so it goes in the values array let float64_case = DictionaryInListTestCase { - _name: "dictionary_float64", + name: "dictionary_float64", dict_type: DataType::Dictionary( Box::new(DataType::Int8), Box::new(DataType::Float64), @@ -2812,24 +2812,27 @@ mod tests { }; // Execute all test cases + let test_name = utf8_case.name; run_dictionary_in_list_test(utf8_case).map_err(|e| { datafusion_common::DataFusionError::Execution(format!( - "Dictionary test failed for UTF8: {}", - e + "Dictionary test '{}' failed: {}", + test_name, e )) })?; + let test_name = int64_case.name; run_dictionary_in_list_test(int64_case).map_err(|e| { datafusion_common::DataFusionError::Execution(format!( - "Dictionary test failed for Int64: {}", - e + "Dictionary test '{}' failed: {}", + test_name, e )) })?; + let test_name = float64_case.name; run_dictionary_in_list_test(float64_case).map_err(|e| { datafusion_common::DataFusionError::Execution(format!( - "Dictionary test failed for Float64: {}", - e + "Dictionary test '{}' failed: {}", + test_name, e )) })?; @@ -2837,7 +2840,7 @@ mod tests { // This tests that multiple rows with the same key (pointing to the same value) // are evaluated correctly let dedup_case = DictionaryInListTestCase { - _name: "dictionary_deduplication", + name: "dictionary_deduplication", dict_type: DataType::Dictionary( Box::new(DataType::Int8), Box::new(DataType::Utf8), @@ -2861,10 +2864,11 @@ mod tests { dict_needle_test: None, }; + let test_name = dedup_case.name; run_dictionary_in_list_test(dedup_case).map_err(|e| { datafusion_common::DataFusionError::Execution(format!( - "Dictionary 
deduplication test failed: {}", - e + "Dictionary test '{}' failed: {}", + test_name, e )) })?; From 7a8b4ce153c128e9acf25b5f1b6a8c2b07b3b817 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 8 Dec 2025 13:30:29 -0600 Subject: [PATCH 13/15] lint, apply https://github.com/apache/datafusion/pull/18832#discussion_r2597474781 --- .../physical-expr/src/expressions/in_list.rs | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 590786b264b3..bd05fbbe05dd 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -241,9 +241,10 @@ impl StaticFilter for Int32StaticFilter { .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; let haystack_has_nulls = self.null_count > 0; + let has_nulls = v.null_count() > 0 || haystack_has_nulls; - let result = match (v.null_count() > 0, haystack_has_nulls, negated) { - (true, _, false) | (false, true, false) => { + let result = match (has_nulls, negated) { + (true, false) => { // Either needle or haystack has nulls, not negated BooleanArray::from_iter(v.iter().map(|value| match value { None => None, @@ -258,7 +259,7 @@ impl StaticFilter for Int32StaticFilter { } })) } - (true, _, true) | (false, true, true) => { + (true, true) => { // Either needle or haystack has nulls, negated BooleanArray::from_iter(v.iter().map(|value| match value { None => None, @@ -273,13 +274,13 @@ impl StaticFilter for Int32StaticFilter { } })) } - (false, false, false) => { + (false, false) => { // No nulls anywhere, not negated BooleanArray::from_iter( v.values().iter().map(|value| self.values.contains(value)), ) } - (false, false, true) => { + (false, true) => { // No nulls anywhere, negated BooleanArray::from_iter( v.values().iter().map(|value| !self.values.contains(value)), @@ 
-820,7 +821,7 @@ mod tests { let col_a = col("a", &schema)?; let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![array.clone()])?; + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::clone(&array)])?; // Helper to format SQL-like representation for error messages let _format_sql = |negated: bool, with_null: bool| -> String { @@ -2815,24 +2816,21 @@ mod tests { let test_name = utf8_case.name; run_dictionary_in_list_test(utf8_case).map_err(|e| { datafusion_common::DataFusionError::Execution(format!( - "Dictionary test '{}' failed: {}", - test_name, e + "Dictionary test '{test_name}' failed: {e}" )) })?; let test_name = int64_case.name; run_dictionary_in_list_test(int64_case).map_err(|e| { datafusion_common::DataFusionError::Execution(format!( - "Dictionary test '{}' failed: {}", - test_name, e + "Dictionary test '{test_name}' failed: {e}" )) })?; let test_name = float64_case.name; run_dictionary_in_list_test(float64_case).map_err(|e| { datafusion_common::DataFusionError::Execution(format!( - "Dictionary test '{}' failed: {}", - test_name, e + "Dictionary test '{test_name}' failed: {e}" )) })?; @@ -2867,8 +2865,7 @@ mod tests { let test_name = dedup_case.name; run_dictionary_in_list_test(dedup_case).map_err(|e| { datafusion_common::DataFusionError::Execution(format!( - "Dictionary test '{}' failed: {}", - test_name, e + "Dictionary test '{test_name}' failed: {e}" )) })?; From 89234b1af2c86651929f09d0cab2834a8595f961 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 9 Dec 2025 09:02:40 -0600 Subject: [PATCH 14/15] add null testing, address pr feedback --- .../physical-expr/src/expressions/in_list.rs | 390 +++++++++++------- 1 file changed, 233 insertions(+), 157 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index bd05fbbe05dd..daa8b263c701 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs 
+++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -245,7 +245,7 @@ impl StaticFilter for Int32StaticFilter { let result = match (has_nulls, negated) { (true, false) => { - // Either needle or haystack has nulls, not negated + // needle has nulls, not negated BooleanArray::from_iter(v.iter().map(|value| match value { None => None, Some(v) => { @@ -260,7 +260,7 @@ impl StaticFilter for Int32StaticFilter { })) } (true, true) => { - // Either needle or haystack has nulls, negated + // needle has nulls, negated BooleanArray::from_iter(v.iter().map(|value| match value { None => None, Some(v) => { @@ -315,15 +315,26 @@ fn evaluate_list( /// Try to evaluate a list of expressions as constants. /// -/// Returns an ArrayRef if all expressions are constants (can be evaluated on an -/// empty RecordBatch), otherwise returns an error. This is used to detect when -/// a list contains only literals, casts of literals, or other constant expressions. +/// Returns: +/// - `Ok(Some(ArrayRef))` if all expressions are constants (can be evaluated on an empty RecordBatch) +/// - `Ok(None)` if the list contains non-constant expressions +/// - `Err(...)` only for actual errors (not for non-constant expressions) +/// +/// This is used to detect when a list contains only literals, casts of literals, +/// or other constant expressions. fn try_evaluate_constant_list( list: &[Arc], schema: &Schema, -) -> Result { +) -> Result> { let batch = RecordBatch::new_empty(Arc::new(schema.clone())); - evaluate_list(list, &batch) + match evaluate_list(list, &batch) { + Ok(array) => Ok(Some(array)), + Err(_) => { + // Non-constant expressions can't be evaluated on an empty batch + // This is not an error, just means we can't use a static filter + Ok(None) + } + } } impl InListExpr { @@ -387,13 +398,15 @@ impl InListExpr { )) } - /// Create a new InList expression with a static filter for constant list expressions. + /// Create a new InList expression, using a static filter when possible. 
/// - /// This validates data types, evaluates the list as constants, and uses specialized - /// StaticFilter implementations for better performance (e.g., Int32StaticFilter for Int32). + /// This validates data types and attempts to create a static filter for constant + /// list expressions. Uses specialized StaticFilter implementations for better + /// performance (e.g., Int32StaticFilter for Int32). /// - /// Returns an error if data types don't match or if the list contains non-constant expressions. - pub fn try_from_static_filter( + /// Returns an error if data types don't match. If the list contains non-constant + /// expressions, falls back to dynamic evaluation at runtime. + pub fn try_new( expr: Arc, list: Vec>, negated: bool, @@ -412,30 +425,13 @@ impl InListExpr { ); } - match try_evaluate_constant_list(&list, schema) { - Ok(in_array) => Ok(Self::new( - expr, - list, - negated, - Some(instantiate_static_filter(in_array)?), - )), - Err(_) => { - // Fall back to non-static filter if list contains non-constant expressions - // Still need to validate types - let expr_data_type = expr.data_type(schema)?; - for list_expr in list.iter() { - let list_expr_data_type = list_expr.data_type(schema)?; - assert_or_internal_err!( - DFSchema::datatype_is_logically_equal( - &expr_data_type, - &list_expr_data_type - ), - "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" - ); - } - Ok(Self::new(expr, list, negated, None)) - } - } + // Try to create a static filter if all list expressions are constants + let static_filter = match try_evaluate_constant_list(&list, schema)? 
{ + Some(in_array) => Some(instantiate_static_filter(in_array)?), + None => None, // Non-constant expressions, fall back to dynamic evaluation + }; + + Ok(Self::new(expr, list, negated, static_filter)) } } impl std::fmt::Display for InListExpr { @@ -638,9 +634,7 @@ pub fn in_list( negated: &bool, schema: &Schema, ) -> Result> { - Ok(Arc::new(InListExpr::try_from_static_filter( - expr, list, *negated, schema, - )?)) + Ok(Arc::new(InListExpr::try_new(expr, list, *negated, schema)?)) } #[cfg(test)] @@ -739,14 +733,25 @@ mod tests { /// and list expressions are already the correct types and don't require casting. macro_rules! in_list_raw { ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ - let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap(); + let col_expr = $COL; + let expr = in_list(Arc::clone(&col_expr), $LIST, $NEGATED, $SCHEMA).unwrap(); let result = expr .evaluate(&$BATCH)? .into_array($BATCH.num_rows()) .expect("Failed to convert to array"); let result = as_boolean_array(&result); let expected = &BooleanArray::from($EXPECTED); - assert_eq!(expected, result); + assert_eq!( + expected, + result, + "Failed for: {}\n{}: {:?}", + fmt_sql(expr.as_ref()), + fmt_sql(col_expr.as_ref()), + col_expr + .evaluate(&$BATCH)? + .into_array($BATCH.num_rows()) + .unwrap() + ); }}; } @@ -755,38 +760,40 @@ mod tests { /// Each test case represents a data type with: /// - `value_in`: A value that appears in both the test array and the IN list (matches → true) /// - `value_not_in`: A value that appears in the test array but NOT in the IN list (doesn't match → false) - /// - `value_in_list`: A value that appears in the IN list but not in the array (filler value) - /// - `null_value`: A null scalar value for NULL handling tests + /// - `other_list_values`: Additional values in the IN list besides `value_in` + /// - `null_value`: Optional null scalar value for NULL handling tests. 
When None, tests + /// without nulls are run, exercising the `(false, false)` and `(false, true)` branches. struct InListPrimitiveTestCase { name: &'static str, value_in: ScalarValue, value_not_in: ScalarValue, - value_in_list: ScalarValue, - null_value: ScalarValue, + other_list_values: Vec, + null_value: Option, } /// Generic test data struct for primitive types. /// - /// Holds the three test values needed for IN LIST tests, allowing the data + /// Holds test values needed for IN LIST tests, allowing the data /// to be declared explicitly and reused across multiple types. #[derive(Clone)] struct PrimitiveTestCaseData { value_in: T, value_not_in: T, - value_in_list: T, + other_list_values: Vec, } /// Helper to create test cases for any primitive type using generic data. /// /// Uses TryInto for flexible type conversion, allowing test data to be /// declared in any convertible type (e.g., i32 for all integer types). + /// Creates a test case WITH null support (for null handling tests). fn primitive_test_case( name: &'static str, constructor: F, data: PrimitiveTestCaseData, ) -> InListPrimitiveTestCase where - D: TryInto, + D: TryInto + Clone, >::Error: Debug, F: Fn(Option) -> ScalarValue, T: Clone, @@ -795,111 +802,170 @@ mod tests { name, value_in: constructor(Some(data.value_in.try_into().unwrap())), value_not_in: constructor(Some(data.value_not_in.try_into().unwrap())), - value_in_list: constructor(Some(data.value_in_list.try_into().unwrap())), - null_value: constructor(None), + other_list_values: data + .other_list_values + .into_iter() + .map(|v| constructor(Some(v.try_into().unwrap()))) + .collect(), + null_value: Some(constructor(None)), + } + } + + /// Helper to create test cases WITHOUT null support. + /// These test cases exercise the `(false, true)` branch (no nulls, negated). 
+ fn primitive_test_case_no_nulls( + name: &'static str, + constructor: F, + data: PrimitiveTestCaseData, + ) -> InListPrimitiveTestCase + where + D: TryInto + Clone, + >::Error: Debug, + F: Fn(Option) -> ScalarValue, + T: Clone, + { + InListPrimitiveTestCase { + name, + value_in: constructor(Some(data.value_in.try_into().unwrap())), + value_not_in: constructor(Some(data.value_not_in.try_into().unwrap())), + other_list_values: data + .other_list_values + .into_iter() + .map(|v| constructor(Some(v.try_into().unwrap()))) + .collect(), + null_value: None, } } /// Runs test cases for multiple types, providing detailed SQL error messages on failure. /// - /// For each test case, runs 4 standard IN LIST scenarios and provides context - /// about the test data and expected behavior when assertions fail. + /// For each test case, runs IN LIST scenarios based on whether null_value is Some or None: + /// - With null_value (Some): 4 tests including null handling + /// - Without null_value (None): 2 tests exercising the no-nulls paths fn run_test_cases(test_cases: Vec) -> Result<()> { for test_case in test_cases { let test_name = test_case.name; // Get the data type from the scalar value let data_type = test_case.value_in.data_type(); - let schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); - - // Create array from scalar values: [value_in, value_not_in, None] - let array = ScalarValue::iter_to_array(vec![ - test_case.value_in.clone(), - test_case.value_not_in.clone(), - test_case.null_value.clone(), - ])?; - let col_a = col("a", &schema)?; - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::clone(&array)])?; - - // Helper to format SQL-like representation for error messages - let _format_sql = |negated: bool, with_null: bool| -> String { - let not_str = if negated { "NOT " } else { "" }; - let null_str = if with_null { - format!(", {}", test_case.null_value) - } else { - String::new() - }; - format!( - "Test '{}': a {}IN ({}, {}{})\n 
where a = [{}, {}, NULL]", - test_name, - not_str, - test_case.value_in, - test_case.value_in_list, - null_str, - test_case.value_in, - test_case.value_not_in - ) + // Build the base list: [value_in, ...other_list_values] + let build_base_list = || -> Vec> { + let mut list = vec![lit(test_case.value_in.clone())]; + list.extend(test_case.other_list_values.iter().map(|v| lit(v.clone()))); + list }; - // Test 1: a IN (value_in, value_in_list) → [true, false, null] - let list = vec![ - lit(test_case.value_in.clone()), - lit(test_case.value_in_list.clone()), - ]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); + match &test_case.null_value { + Some(null_val) => { + // Tests WITH nulls in the needle array + let schema = + Schema::new(vec![Field::new("a", data_type.clone(), true)]); + + // Create array from scalar values: [value_in, value_not_in, None] + let array = ScalarValue::iter_to_array(vec![ + test_case.value_in.clone(), + test_case.value_not_in.clone(), + null_val.clone(), + ])?; + + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::clone(&array)], + )?; + + // Test 1: a IN (list) → [true, false, null] + let list = build_base_list(); + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); - // Test 2: a NOT IN (value_in, value_in_list) → [false, true, null] - let list = vec![ - lit(test_case.value_in.clone()), - lit(test_case.value_in_list.clone()), - ]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); + // Test 2: a NOT IN (list) → [false, true, null] + let list = build_base_list(); + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); - // Test 3: a IN (value_in, value_in_list, NULL) → [true, null, null] - let list = vec![ - lit(test_case.value_in.clone()), - 
lit(test_case.value_in_list.clone()), - lit(test_case.null_value.clone()), - ]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); + // Test 3: a IN (list, NULL) → [true, null, null] + let mut list = build_base_list(); + list.push(lit(null_val.clone())); + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); - // Test 4: a NOT IN (value_in, value_in_list, NULL) → [false, null, null] - let list = vec![ - lit(test_case.value_in), - lit(test_case.value_in_list), - lit(test_case.null_value), - ]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); + // Test 4: a NOT IN (list, NULL) → [false, null, null] + let mut list = build_base_list(); + list.push(lit(null_val.clone())); + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + } + None => { + // Tests WITHOUT nulls - exercises the (false, false) and (false, true) branches + let schema = + Schema::new(vec![Field::new("a", data_type.clone(), false)]); + + // Create array from scalar values: [value_in, value_not_in] (no NULL) + let array = ScalarValue::iter_to_array(vec![ + test_case.value_in.clone(), + test_case.value_not_in.clone(), + ])?; + + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::clone(&array)], + )?; + + // Test 1: a IN (list) → [true, false] - exercises (false, false) branch + let list = build_base_list(); + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false)], + Arc::clone(&col_a), + &schema + ); + + // Test 2: a NOT IN (list) → [false, true] - exercises (false, true) branch + let list = build_base_list(); + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true)], + Arc::clone(&col_a), + &schema + ); + + eprintln!( + "Test '{}': exercised (false, true) branch (no nulls, negated)", + test_name + ); + } 
+ } } Ok(()) @@ -907,16 +973,17 @@ mod tests { /// Test IN LIST for all integer types (Int8/16/32/64, UInt8/16/32/64). /// - /// Test data: values 0 (in list), 2 (not in list), 1 (filler) + /// Test data: 0 (in list), 2 (not in list), [1, 3, 5] (other list values) #[test] fn in_list_int_types() -> Result<()> { let int_data = PrimitiveTestCaseData { value_in: 0, value_not_in: 2, - value_in_list: 1, + other_list_values: vec![1, 3, 5], }; run_test_cases(vec![ + // Tests WITH nulls primitive_test_case("int8", ScalarValue::Int8, int_data.clone()), primitive_test_case("int16", ScalarValue::Int16, int_data.clone()), primitive_test_case("int32", ScalarValue::Int32, int_data.clone()), @@ -924,19 +991,21 @@ mod tests { primitive_test_case("uint8", ScalarValue::UInt8, int_data.clone()), primitive_test_case("uint16", ScalarValue::UInt16, int_data.clone()), primitive_test_case("uint32", ScalarValue::UInt32, int_data.clone()), - primitive_test_case("uint64", ScalarValue::UInt64, int_data), + primitive_test_case("uint64", ScalarValue::UInt64, int_data.clone()), + // Tests WITHOUT nulls - exercises (false, true) branch + primitive_test_case_no_nulls("int32_no_nulls", ScalarValue::Int32, int_data), ]) } /// Test IN LIST for all string types (Utf8, LargeUtf8, Utf8View). /// - /// Test data: "a" (in list), "d" (not in list), "b" (filler) + /// Test data: "a" (in list), "d" (not in list), ["b", "c"] (other list values) #[test] fn in_list_string_types() -> Result<()> { let string_data = PrimitiveTestCaseData { value_in: "a", value_not_in: "d", - value_in_list: "b", + other_list_values: vec!["b", "c"], }; run_test_cases(vec![ @@ -952,13 +1021,13 @@ mod tests { /// Test IN LIST for all binary types (Binary, LargeBinary, BinaryView). 
/// - /// Test data: [1,2,3] (in list), [1,2,2] (not in list), [4,5,6] (filler) + /// Test data: [1,2,3] (in list), [1,2,2] (not in list), [[4,5,6], [7,8,9]] (other list values) #[test] fn in_list_binary_types() -> Result<()> { let binary_data = PrimitiveTestCaseData { value_in: vec![1_u8, 2, 3], value_not_in: vec![1_u8, 2, 2], - value_in_list: vec![4_u8, 5, 6], + other_list_values: vec![vec![4_u8, 5, 6], vec![7_u8, 8, 9]], }; run_test_cases(vec![ @@ -974,13 +1043,13 @@ mod tests { /// Test IN LIST for date types (Date32, Date64). /// - /// Test data: 0 (in list), 2 (not in list), 1 (filler) + /// Test data: 0 (in list), 2 (not in list), [1, 3] (other list values) #[test] fn in_list_date_types() -> Result<()> { let date_data = PrimitiveTestCaseData { value_in: 0, value_not_in: 2, - value_in_list: 1, + other_list_values: vec![1, 3], }; run_test_cases(vec![ @@ -991,29 +1060,35 @@ mod tests { /// Test IN LIST for Decimal128 type. /// - /// Test data: 0 (in list), 200 (not in list), 100 (filler) with precision=10, scale=2 + /// Test data: 0 (in list), 200 (not in list), [100, 300] (other list values) with precision=10, scale=2 #[test] fn in_list_decimal() -> Result<()> { run_test_cases(vec![InListPrimitiveTestCase { name: "decimal128", value_in: ScalarValue::Decimal128(Some(0), 10, 2), value_not_in: ScalarValue::Decimal128(Some(200), 10, 2), - value_in_list: ScalarValue::Decimal128(Some(100), 10, 2), - null_value: ScalarValue::Decimal128(None, 10, 2), + other_list_values: vec![ + ScalarValue::Decimal128(Some(100), 10, 2), + ScalarValue::Decimal128(Some(300), 10, 2), + ], + null_value: Some(ScalarValue::Decimal128(None, 10, 2)), }]) } /// Test IN LIST for timestamp types. 
/// - /// Test data: 0 (in list), 2000 (not in list), 1000 (filler) + /// Test data: 0 (in list), 2000 (not in list), [1000, 3000] (other list values) #[test] fn in_list_timestamp_types() -> Result<()> { run_test_cases(vec![InListPrimitiveTestCase { name: "timestamp_nanosecond", value_in: ScalarValue::TimestampNanosecond(Some(0), None), value_not_in: ScalarValue::TimestampNanosecond(Some(2000), None), - value_in_list: ScalarValue::TimestampNanosecond(Some(1000), None), - null_value: ScalarValue::TimestampNanosecond(None, None), + other_list_values: vec![ + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(3000), None), + ], + null_value: Some(ScalarValue::TimestampNanosecond(None, None)), }]) } @@ -2586,21 +2661,22 @@ mod tests { let schema = Schema::new(vec![Field::new("a", dict_type.clone(), false)]); let col_a = col("a", &schema)?; - // Create IN list with Int32 literals: (1, 2, 3) - let list = vec![lit(1i32), lit(2i32), lit(3i32)]; + // Create IN list with Int32 literals: (100, 200, 300) + let list = vec![lit(100i32), lit(200i32), lit(300i32)]; // Create InListExpr via in_list() - this uses Int32StaticFilter for Int32 lists let expr = in_list(col_a, list, &false, &schema)?; - // Create dictionary-encoded batch with values [1, 2, 5] - // Dictionary: keys [0, 1, 2] -> values [1, 2, 5] + // Create dictionary-encoded batch with values [100, 200, 500] + // Dictionary: keys [0, 1, 2] -> values [100, 200, 500] + // Using values clearly distinct from keys to avoid confusion let keys = Int8Array::from(vec![0, 1, 2]); - let values = Int32Array::from(vec![1, 2, 5]); + let values = Int32Array::from(vec![100, 200, 500]); let dict_array: ArrayRef = Arc::new(DictionaryArray::try_new(keys, Arc::new(values))?); let batch = RecordBatch::try_new(Arc::new(schema), vec![dict_array])?; - // Expected: [1 IN (1,2,3), 2 IN (1,2,3), 5 IN (1,2,3)] = [true, true, false] + // Expected: [100 IN (100,200,300), 200 IN (100,200,300), 500 IN 
(100,200,300)] = [true, true, false] let result = expr.evaluate(&batch)?.into_array(3)?; let result = as_boolean_array(&result); assert_eq!(result, &BooleanArray::from(vec![true, true, false])); From 6a499ae78f625dedf48af4eeb21e679327c7056d Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 9 Dec 2025 09:14:57 -0600 Subject: [PATCH 15/15] fix lint --- datafusion/physical-expr/src/expressions/in_list.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index daa8b263c701..c1f23ed2aed5 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -961,8 +961,7 @@ mod tests { ); eprintln!( - "Test '{}': exercised (false, true) branch (no nulls, negated)", - test_name + "Test '{test_name}': exercised (false, true) branch (no nulls, negated)", ); } }