diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index fba8f4fbe4d9..b6b67c85c488 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -26,7 +26,7 @@ use crate::PhysicalExpr; use crate::physical_expr::physical_exprs_bag_equal; use arrow::array::*; -use arrow::buffer::BooleanBuffer; +use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::kernels::boolean::{not, or_kleene}; use arrow::compute::{SortOptions, take}; use arrow::datatypes::*; @@ -91,7 +91,11 @@ impl StaticFilter for ArrayStaticFilter { if v.data_type() == &DataType::Null || self.in_array.data_type() == &DataType::Null { - return Ok(BooleanArray::from(vec![None; v.len()])); + let nulls = NullBuffer::new_null(v.len()); + return Ok(BooleanArray::new( + BooleanBuffer::new_unset(v.len()), + Some(nulls), + )); } downcast_dictionary_array! { @@ -138,9 +142,20 @@ fn instantiate_static_filter( in_array: ArrayRef, ) -> Result> { match in_array.data_type() { + // Integer primitive types + DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), + DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), + DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), + DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)), + DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), + DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), + DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), + // Float primitive types (use ordered wrappers for Hash/Eq) + DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)), + DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)), _ => { - /* fall through to generic implementation */ + /* fall through to generic implementation for unsupported types (Struct, etc.) */ Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) } } @@ -198,99 +213,325 @@ impl ArrayStaticFilter { } } -struct Int32StaticFilter { - null_count: usize, - values: HashSet, +/// Wrapper for f32 that implements Hash and Eq using bit comparison. +/// This treats NaN values as equal to each other when they have the same bit pattern. +#[derive(Clone, Copy)] +struct OrderedFloat32(f32); + +impl Hash for OrderedFloat32 { + fn hash(&self, state: &mut H) { + self.0.to_ne_bytes().hash(state); + } +} + +impl PartialEq for OrderedFloat32 { + fn eq(&self, other: &Self) -> bool { + self.0.to_bits() == other.0.to_bits() + } } -impl Int32StaticFilter { - fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; +impl Eq for OrderedFloat32 {} - let mut values = HashSet::with_capacity(in_array.len()); - let null_count = in_array.null_count(); +impl From for OrderedFloat32 { + fn from(v: f32) -> Self { + Self(v) + } +} - for v in in_array.iter().flatten() { - values.insert(v); - } +/// Wrapper for f64 that implements Hash and Eq using bit comparison. +/// This treats NaN values as equal to each other when they have the same bit pattern. +#[derive(Clone, Copy)] +struct OrderedFloat64(f64); - Ok(Self { null_count, values }) +impl Hash for OrderedFloat64 { + fn hash(&self, state: &mut H) { + self.0.to_ne_bytes().hash(state); } } -impl StaticFilter for Int32StaticFilter { - fn null_count(&self) -> usize { - self.null_count +impl PartialEq for OrderedFloat64 { + fn eq(&self, other: &Self) -> bool { + self.0.to_bits() == other.0.to_bits() } +} - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! { - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) +impl Eq for OrderedFloat64 {} + +impl From for OrderedFloat64 { + fn from(v: f64) -> Self { + Self(v) + } +} + +// Macro to generate specialized StaticFilter implementations for primitive types +macro_rules! primitive_static_filter { + ($Name:ident, $ArrowType:ty) => { + struct $Name { + null_count: usize, + values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>, + } + + impl $Name { + fn try_new(in_array: &ArrayRef) -> Result { + let in_array = in_array + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + + let mut values = HashSet::with_capacity(in_array.len()); + let null_count = in_array.null_count(); + + for v in in_array.iter().flatten() { + values.insert(v); + } + + Ok(Self { null_count, values }) } - _ => {} } - let v = v - .as_primitive_opt::() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; - - let haystack_has_nulls = self.null_count > 0; - let has_nulls = v.null_count() > 0 || haystack_has_nulls; - - let result = match (has_nulls, negated) { - (true, false) => { - // needle has nulls, not negated - BooleanArray::from_iter(v.iter().map(|value| match value { - None => None, - Some(v) => { - if self.values.contains(&v) { - Some(true) - } else if haystack_has_nulls { - None + impl StaticFilter for $Name { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Handle dictionary arrays by recursing on the values + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + + let v = v + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + + let haystack_has_nulls = self.null_count > 0; + + let needle_values = v.values(); + let needle_nulls = v.nulls(); + let needle_has_nulls = v.null_count() > 0; + + // Truth table for `value [NOT] IN (set)` with SQL three-valued logic: + // ("-" means the value doesn't affect the result) + // + // | needle_null | haystack_null | negated | in set? | result | + // |-------------|---------------|---------|---------|--------| + // | true | - | false | - | null | + // | true | - | true | - | null | + // | false | true | false | yes | true | + // | false | true | false | no | null | + // | false | true | true | yes | false | + // | false | true | true | no | null | + // | false | false | false | yes | true | + // | false | false | false | no | false | + // | false | false | true | yes | false | + // | false | false | true | no | true | + + // Compute the "contains" result using collect_bool (fast batched approach) + // This ignores nulls - we handle them separately + let contains_buffer = if negated { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + !self.values.contains(&needle_values[i]) + }) + } else { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + self.values.contains(&needle_values[i]) + }) + }; + + // Compute the null mask + // Output is null when: + // 1. needle value is null, OR + // 2. needle value is not in set AND haystack has nulls + let result_nulls = match (needle_has_nulls, haystack_has_nulls) { + (false, false) => { + // No nulls anywhere + None + } + (true, false) => { + // Only needle has nulls - just use needle's null mask + needle_nulls.cloned() + } + (false, true) => { + // Only haystack has nulls - result is null when value not in set + // Valid (not null) when original "in set" is true + // For NOT IN: contains_buffer = !original, so validity = !contains_buffer + let validity = if negated { + !&contains_buffer } else { - Some(false) - } + contains_buffer.clone() + }; + Some(NullBuffer::new(validity)) } - })) - } - (true, true) => { - // needle has nulls, negated - BooleanArray::from_iter(v.iter().map(|value| match value { - None => None, - Some(v) => { - if self.values.contains(&v) { - Some(false) - } else if haystack_has_nulls { - None + (true, true) => { + // Both have nulls - combine needle nulls with haystack-induced nulls + let needle_validity = needle_nulls.map(|n| n.inner().clone()) + .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); + + // Valid when original "in set" is true (see above) + let haystack_validity = if negated { + !&contains_buffer } else { - Some(true) - } + contains_buffer.clone() + }; + + // Combined validity: valid only where both are valid + let combined_validity = &needle_validity & &haystack_validity; + Some(NullBuffer::new(combined_validity)) } - })) + }; + + Ok(BooleanArray::new(contains_buffer, result_nulls)) } - (false, false) => { - // No nulls anywhere, not negated - BooleanArray::from_iter( - v.values().iter().map(|value| self.values.contains(value)), - ) + } + }; +} + +// Generate specialized filters for all integer primitive types +primitive_static_filter!(Int8StaticFilter, Int8Type); +primitive_static_filter!(Int16StaticFilter, Int16Type); +primitive_static_filter!(Int32StaticFilter, Int32Type); +primitive_static_filter!(Int64StaticFilter, Int64Type); +primitive_static_filter!(UInt8StaticFilter, UInt8Type); +primitive_static_filter!(UInt16StaticFilter, UInt16Type); +primitive_static_filter!(UInt32StaticFilter, UInt32Type); +primitive_static_filter!(UInt64StaticFilter, UInt64Type); + +// Macro to generate specialized StaticFilter implementations for float types +// Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics +macro_rules! float_static_filter { + ($Name:ident, $ArrowType:ty, $OrderedType:ty) => { + struct $Name { + null_count: usize, + values: HashSet<$OrderedType>, + } + + impl $Name { + fn try_new(in_array: &ArrayRef) -> Result { + let in_array = in_array + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + + let mut values = HashSet::with_capacity(in_array.len()); + let null_count = in_array.null_count(); + + for v in in_array.iter().flatten() { + values.insert(<$OrderedType>::from(v)); + } + + Ok(Self { null_count, values }) } - (false, true) => { - // No nulls anywhere, negated - BooleanArray::from_iter( - v.values().iter().map(|value| !self.values.contains(value)), - ) + } + + impl StaticFilter for $Name { + fn null_count(&self) -> usize { + self.null_count } - }; - Ok(result) - } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Handle dictionary arrays by recursing on the values + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + + let v = v + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + + let haystack_has_nulls = self.null_count > 0; + + let needle_values = v.values(); + let needle_nulls = v.nulls(); + let needle_has_nulls = v.null_count() > 0; + + // Truth table for `value [NOT] IN (set)` with SQL three-valued logic: + // ("-" means the value doesn't affect the result) + // + // | needle_null | haystack_null | negated | in set? | result | + // |-------------|---------------|---------|---------|--------| + // | true | - | false | - | null | + // | true | - | true | - | null | + // | false | true | false | yes | true | + // | false | true | false | no | null | + // | false | true | true | yes | false | + // | false | true | true | no | null | + // | false | false | false | yes | true | + // | false | false | false | no | false | + // | false | false | true | yes | false | + // | false | false | true | no | true | + + // Compute the "contains" result using collect_bool (fast batched approach) + // This ignores nulls - we handle them separately + let contains_buffer = if negated { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + !self.values.contains(&<$OrderedType>::from(needle_values[i])) + }) + } else { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + self.values.contains(&<$OrderedType>::from(needle_values[i])) + }) + }; + + // Compute the null mask + // Output is null when: + // 1. needle value is null, OR + // 2. needle value is not in set AND haystack has nulls + let result_nulls = match (needle_has_nulls, haystack_has_nulls) { + (false, false) => { + // No nulls anywhere + None + } + (true, false) => { + // Only needle has nulls - just use needle's null mask + needle_nulls.cloned() + } + (false, true) => { + // Only haystack has nulls - result is null when value not in set + // Valid (not null) when original "in set" is true + // For NOT IN: contains_buffer = !original, so validity = !contains_buffer + let validity = if negated { + !&contains_buffer + } else { + contains_buffer.clone() + }; + Some(NullBuffer::new(validity)) + } + (true, true) => { + // Both have nulls - combine needle nulls with haystack-induced nulls + let needle_validity = needle_nulls.map(|n| n.inner().clone()) + .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); + + // Valid when original "in set" is true (see above) + let haystack_validity = if negated { + !&contains_buffer + } else { + contains_buffer.clone() + }; + + // Combined validity: valid only where both are valid + let combined_validity = &needle_validity & &haystack_validity; + Some(NullBuffer::new(combined_validity)) + } + }; + + Ok(BooleanArray::new(contains_buffer, result_nulls)) + } + } + }; } +// Generate specialized filters for float types using ordered wrappers +float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32); +float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64); + /// Evaluates the list of expressions into an array, flattening any dictionaries fn evaluate_list( list: &[Arc], @@ -500,8 +741,12 @@ impl PhysicalExpr for InListExpr { if scalar.is_null() { // SQL three-valued logic: null IN (...) is always null // The code below would handle this correctly but this is a faster path + let nulls = NullBuffer::new_null(num_rows); return Ok(ColumnarValue::Array(Arc::new( - BooleanArray::from(vec![None; num_rows]), + BooleanArray::new( + BooleanBuffer::new_unset(num_rows), + Some(nulls), + ), ))); } // Use a 1 row array to avoid code duplication/branching @@ -512,12 +757,15 @@ impl PhysicalExpr for InListExpr { // Broadcast the single result to all rows // Must check is_null() to preserve NULL values (SQL three-valued logic) if result_array.is_null(0) { - BooleanArray::from(vec![None; num_rows]) + let nulls = NullBuffer::new_null(num_rows); + BooleanArray::new( + BooleanBuffer::new_unset(num_rows), + Some(nulls), + ) + } else if result_array.value(0) { + BooleanArray::new(BooleanBuffer::new_set(num_rows), None) } else { - BooleanArray::from_iter(std::iter::repeat_n( - result_array.value(0), - num_rows, - )) + BooleanArray::new(BooleanBuffer::new_unset(num_rows), None) } } }