From b57f4937138056301c174c6414e9fc6ca08d699c Mon Sep 17 00:00:00 2001 From: Geoffrey Claude Date: Mon, 8 Dec 2025 10:32:03 +0100 Subject: [PATCH 1/4] fix: inverted null_percent logic in in_list benchmark --- datafusion/physical-expr/benches/in_list.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index 778204055bbd..705fafb253ed 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -49,10 +49,12 @@ fn do_benches( null_percent: f64, ) { let mut rng = StdRng::seed_from_u64(120320); + let non_null_percent = 1.0 - null_percent; + for string_length in [5, 10, 20] { let values: StringArray = (0..array_length) .map(|_| { - rng.random_bool(null_percent) + rng.random_bool(non_null_percent) .then(|| random_string(&mut rng, string_length)) }) .collect(); @@ -72,7 +74,7 @@ fn do_benches( } let values: Float32Array = (0..array_length) - .map(|_| rng.random_bool(null_percent).then(|| rng.random())) + .map(|_| rng.random_bool(non_null_percent).then(|| rng.random())) .collect(); let in_list: Vec<_> = (0..in_list_length) @@ -87,7 +89,7 @@ fn do_benches( ); let values: Int32Array = (0..array_length) - .map(|_| rng.random_bool(null_percent).then(|| rng.random())) + .map(|_| rng.random_bool(non_null_percent).then(|| rng.random())) .collect(); let in_list: Vec<_> = (0..in_list_length) From f1f064b85b18eee42ccbca516d092a367a0c2d60 Mon Sep 17 00:00:00 2001 From: Geoffrey Claude Date: Mon, 8 Dec 2025 13:09:51 +0100 Subject: [PATCH 2/4] bench: add Utf8 and Utf8View benchmarks for InList - Add StringViewArray benchmarks alongside existing StringArray benchmarks - Use explicit ScalarValue::Utf8 for StringArray (was using ScalarValue::from which creates Utf8View) --- datafusion/physical-expr/benches/in_list.rs | 175 +++++++++++++------- 1 file changed, 114 insertions(+), 61 deletions(-) diff --git 
a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index 705fafb253ed..664bc2341074 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, Float32Array, Int32Array, StringArray}; +use arrow::array::{ + Array, ArrayRef, Float32Array, Int32Array, StringArray, StringViewArray, +}; use arrow::datatypes::{Field, Schema}; use arrow::record_batch::RecordBatch; use criterion::{criterion_group, criterion_main, Criterion}; @@ -23,9 +25,11 @@ use datafusion_common::ScalarValue; use datafusion_physical_expr::expressions::{col, in_list, lit}; use rand::distr::Alphanumeric; use rand::prelude::*; +use std::any::TypeId; use std::hint::black_box; use std::sync::Arc; +/// Measures how long `in_list(col("a"), exprs)` takes to evaluate against a single RecordBatch. fn do_bench(c: &mut Criterion, name: &str, values: ArrayRef, exprs: &[ScalarValue]) { let schema = Schema::new(vec![Field::new("a", values.data_type().clone(), true)]); let exprs = exprs.iter().map(|s| lit(s.clone())).collect(); @@ -37,79 +41,128 @@ fn do_bench(c: &mut Criterion, name: &str, values: ArrayRef, exprs: &[ScalarValu }); } +/// Generates a random alphanumeric string of the specified length. 
fn random_string(rng: &mut StdRng, len: usize) -> String { let value = rng.sample_iter(&Alphanumeric).take(len).collect(); String::from_utf8(value).unwrap() } -fn do_benches( - c: &mut Criterion, - array_length: usize, - in_list_length: usize, - null_percent: f64, -) { - let mut rng = StdRng::seed_from_u64(120320); - let non_null_percent = 1.0 - null_percent; - - for string_length in [5, 10, 20] { - let values: StringArray = (0..array_length) - .map(|_| { - rng.random_bool(non_null_percent) - .then(|| random_string(&mut rng, string_length)) - }) - .collect(); - - let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::from(random_string(&mut rng, string_length))) - .collect(); - - do_bench( - c, - &format!( - "in_list_utf8({string_length}) ({array_length}, {null_percent}) IN ({in_list_length}, 0)" - ), - Arc::new(values), - &in_list, - ) +const IN_LIST_LENGTHS: [usize; 3] = [3, 8, 100]; +const NULL_PERCENTS: [f64; 2] = [0., 0.2]; +const STRING_LENGTHS: [usize; 3] = [3, 12, 100]; +const ARRAY_LENGTH: usize = 1024; + +/// Returns a friendly type name for the array type. +fn array_type_name() -> &'static str { + let id = TypeId::of::(); + if id == TypeId::of::() { + "Utf8" + } else if id == TypeId::of::() { + "Utf8View" + } else if id == TypeId::of::() { + "Float32" + } else if id == TypeId::of::() { + "Int32" + } else { + "Unknown" } +} - let values: Float32Array = (0..array_length) - .map(|_| rng.random_bool(non_null_percent).then(|| rng.random())) - .collect(); +/// Builds a benchmark name from array type, list size, and null percentage. +fn bench_name(in_list_length: usize, null_percent: f64) -> String { + format!( + "in_list/{}/list={in_list_length}/nulls={}%", + array_type_name::(), + (null_percent * 100.0) as u32 + ) +} - let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Float32(Some(rng.random()))) - .collect(); +/// Runs in_list benchmarks for a string array type across all list-size × null-ratio × string-length combinations. 
+fn bench_string_type( + c: &mut Criterion, + rng: &mut StdRng, + make_scalar: fn(String) -> ScalarValue, +) where + A: Array + FromIterator> + 'static, +{ + for in_list_length in IN_LIST_LENGTHS { + for null_percent in NULL_PERCENTS { + for string_length in STRING_LENGTHS { + let values: A = (0..ARRAY_LENGTH) + .map(|_| { + rng.random_bool(1.0 - null_percent) + .then(|| random_string(rng, string_length)) + }) + .collect(); + + let in_list: Vec<_> = (0..in_list_length) + .map(|_| make_scalar(random_string(rng, string_length))) + .collect(); + + do_bench( + c, + &format!( + "{}/str={string_length}", + bench_name::(in_list_length, null_percent) + ), + Arc::new(values), + &in_list, + ) + } + } + } +} - do_bench( - c, - &format!("in_list_f32 ({array_length}, {null_percent}) IN ({in_list_length}, 0)"), - Arc::new(values), - &in_list, - ); +/// Runs in_list benchmarks for a numeric array type across all list-size × null-ratio combinations. +fn bench_numeric_type( + c: &mut Criterion, + rng: &mut StdRng, + mut gen_value: impl FnMut(&mut StdRng) -> T, + make_scalar: fn(T) -> ScalarValue, +) where + A: Array + FromIterator> + 'static, +{ + for in_list_length in IN_LIST_LENGTHS { + for null_percent in NULL_PERCENTS { + let values: A = (0..ARRAY_LENGTH) + .map(|_| rng.random_bool(1.0 - null_percent).then(|| gen_value(rng))) + .collect(); + + let in_list: Vec<_> = (0..in_list_length) + .map(|_| make_scalar(gen_value(rng))) + .collect(); + + do_bench( + c, + &bench_name::(in_list_length, null_percent), + Arc::new(values), + &in_list, + ); + } + } +} - let values: Int32Array = (0..array_length) - .map(|_| rng.random_bool(non_null_percent).then(|| rng.random())) - .collect(); +/// Entry point: registers in_list benchmarks for Utf8, Utf8View, Float32, and Int32 arrays. 
+fn criterion_benchmark(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(120320); - let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Int32(Some(rng.random()))) - .collect(); + // Benchmarks for string array types (Utf8, Utf8View) + bench_string_type::(c, &mut rng, |s| ScalarValue::Utf8(Some(s))); + bench_string_type::(c, &mut rng, |s| ScalarValue::Utf8View(Some(s))); - do_bench( + // Benchmarks for numeric types + bench_numeric_type::( c, - &format!("in_list_i32 ({array_length}, {null_percent}) IN ({in_list_length}, 0)"), - Arc::new(values), - &in_list, - ) -} - -fn criterion_benchmark(c: &mut Criterion) { - for in_list_length in [1, 3, 10, 100] { - for null_percent in [0., 0.2] { - do_benches(c, 1024, in_list_length, null_percent) - } - } + &mut rng, + |rng| rng.random(), + |v| ScalarValue::Float32(Some(v)), + ); + bench_numeric_type::( + c, + &mut rng, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ); } criterion_group!(benches, criterion_benchmark); From 5b85900d986179e4608500712a849e45c4c1cf36 Mon Sep 17 00:00:00 2001 From: Geoffrey Claude Date: Mon, 8 Dec 2025 11:21:12 +0100 Subject: [PATCH 3/4] perf(in_list): replace iterator-based result building with vectorized collect_bool The previous implementation used BooleanArray::from_iter and BooleanBufferBuilder with element-by-element appends, which incur iterator overhead and prevent vectorization. This commit switches to BooleanBuffer::collect_bool, a batch operation that pre-allocates the exact buffer size and enables SIMD optimization. Since collect_bool guarantees the index is always in bounds, we can safely use unchecked array access (value_unchecked, get_unchecked) to eliminate bounds checks in the hot loop. The null-handling match is also simplified from a 3-way tuple to a 2-way check by pre-combining needle and haystack null flags. 
--- .../physical-expr/src/expressions/in_list.rs | 97 +++++++++---------- 1 file changed, 46 insertions(+), 51 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 51daa073efa1..c750d7876440 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -257,62 +257,57 @@ macro_rules! primitive_static_filter { .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; let haystack_has_nulls = self.null_count > 0; - - let result = match (v.null_count() > 0, haystack_has_nulls, negated) { - (true, _, false) | (false, true, false) => { - // Either needle or haystack has nulls, not negated - BooleanArray::from_iter(v.iter().map(|value| { - match value { - // SQL three-valued logic: null IN (...) is always null - None => None, - Some(v) => { - if self.values.contains(&v) { - Some(true) - } else if haystack_has_nulls { - // value not in set, but set has nulls -> null - None - } else { - Some(false) - } - } - } - })) + let has_nulls = v.null_count() > 0 || haystack_has_nulls; + + // SAFETY: collect_bool guarantees i < len for all closure calls + let result = match (has_nulls, negated) { + (true, false) => { + // Has nulls somewhere, not negated + let len = v.len(); + let values_buf = BooleanBuffer::collect_bool(len, |i| { + // SAFETY: i < len is guaranteed by collect_bool + // If found in set -> true, otherwise false (null handled by validity) + v.is_valid(i) && self.values.contains(unsafe { &v.value_unchecked(i) }) + }); + let nulls_buf = BooleanBuffer::collect_bool(len, |i| { + // SAFETY: i < len is guaranteed by collect_bool + // Valid (not null) if: needle is valid AND (found OR haystack has no nulls) + v.is_valid(i) && (self.values.contains(unsafe { &v.value_unchecked(i) }) || !haystack_has_nulls) + }); + BooleanArray::new(values_buf, Some(NullBuffer::new(nulls_buf))) } - 
(true, _, true) | (false, true, true) => { - // Either needle or haystack has nulls, negated - BooleanArray::from_iter(v.iter().map(|value| { - match value { - // SQL three-valued logic: null NOT IN (...) is always null - None => None, - Some(v) => { - if self.values.contains(&v) { - Some(false) - } else if haystack_has_nulls { - // value not in set, but set has nulls -> null - None - } else { - Some(true) - } - } - } - })) + (true, true) => { + // Has nulls somewhere, negated + let len = v.len(); + let values_buf = BooleanBuffer::collect_bool(len, |i| { + // SAFETY: i < len is guaranteed by collect_bool + // If found in set -> false, otherwise true (null handled by validity) + v.is_valid(i) && !self.values.contains(unsafe { &v.value_unchecked(i) }) + }); + let nulls_buf = BooleanBuffer::collect_bool(len, |i| { + // SAFETY: i < len is guaranteed by collect_bool + // Valid (not null) if: needle is valid AND (found OR haystack has no nulls) + v.is_valid(i) && (self.values.contains(unsafe { &v.value_unchecked(i) }) || !haystack_has_nulls) + }); + BooleanArray::new(values_buf, Some(NullBuffer::new(nulls_buf))) } - (false, false, false) => { - // no nulls anywhere, not negated + (false, false) => { + // No nulls anywhere, not negated let values = v.values(); - let mut builder = BooleanBufferBuilder::new(values.len()); - for value in values.iter() { - builder.append(self.values.contains(value)); - } - BooleanArray::new(builder.finish(), None) + let values_buf = BooleanBuffer::collect_bool(values.len(), |i| { + // SAFETY: i < len is guaranteed by collect_bool + self.values.contains(unsafe { values.get_unchecked(i) }) + }); + BooleanArray::new(values_buf, None) } - (false, false, true) => { + (false, true) => { + // No nulls anywhere, negated let values = v.values(); - let mut builder = BooleanBufferBuilder::new(values.len()); - for value in values.iter() { - builder.append(!self.values.contains(value)); - } - BooleanArray::new(builder.finish(), None) + let values_buf = 
BooleanBuffer::collect_bool(values.len(), |i| { + // SAFETY: i < len is guaranteed by collect_bool + !self.values.contains(unsafe { values.get_unchecked(i) }) + }); + BooleanArray::new(values_buf, None) } }; Ok(result) From 398dc8f160d892a5998b8e81368819578145d667 Mon Sep 17 00:00:00 2001 From: Geoffrey Claude Date: Mon, 8 Dec 2025 12:04:41 +0100 Subject: [PATCH 4/4] perf(in_list): optimize lookup for small lists and Utf8View short strings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For small IN lists (≤8 elements), hashing overhead dominates execution time. This commit uses binary search instead, which is faster for small lists. Utf8View gains a short-string filter that compares raw u128 views directly - the same layout Arrow uses for inline storage (≤12 bytes). This turns string comparison into fast integer comparison. Lists with long strings fall through to the generic hash-based filter. Benchmarks show significant improvement for Utf8View short strings and primitives with small lists. --- .../physical-expr/src/expressions/in_list.rs | 538 ++++++++++++++---- 1 file changed, 435 insertions(+), 103 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index c750d7876440..78eaf6fadf4b 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -141,21 +141,58 @@ impl StaticFilter for ArrayStaticFilter { fn instantiate_static_filter( in_array: ArrayRef, +) -> Result> { + if in_array.len() <= SORTED_LOOKUP_MAX_LEN { + instantiate_sorted_filter(in_array) + } else { + instantiate_hashed_filter(in_array) + } +} + +/// Sorted filter using binary search. Best for small lists (≤8 elements). 
+fn instantiate_sorted_filter( + in_array: ArrayRef, ) -> Result> { match in_array.data_type() { - // Integer primitive types - DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), - DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), - DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), - DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), - DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)), - DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), - DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), - DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), - _ => { - /* fall through to generic implementation for unsupported types (Float32/Float64, Struct, etc.) */ - Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) - } + DataType::Int8 => Ok(Arc::new(Int8SortedFilter::try_new(&in_array)?)), + DataType::Int16 => Ok(Arc::new(Int16SortedFilter::try_new(&in_array)?)), + DataType::Int32 => Ok(Arc::new(Int32SortedFilter::try_new(&in_array)?)), + DataType::Int64 => Ok(Arc::new(Int64SortedFilter::try_new(&in_array)?)), + DataType::UInt8 => Ok(Arc::new(UInt8SortedFilter::try_new(&in_array)?)), + DataType::UInt16 => Ok(Arc::new(UInt16SortedFilter::try_new(&in_array)?)), + DataType::UInt32 => Ok(Arc::new(UInt32SortedFilter::try_new(&in_array)?)), + DataType::UInt64 => Ok(Arc::new(UInt64SortedFilter::try_new(&in_array)?)), + DataType::Float32 => Ok(Arc::new(Float32SortedFilter::try_new(&in_array)?)), + DataType::Float64 => Ok(Arc::new(Float64SortedFilter::try_new(&in_array)?)), + DataType::Utf8View => match Utf8ViewSortedFilter::try_new(&in_array) { + Some(Ok(filter)) => Ok(Arc::new(filter)), + Some(Err(e)) => Err(e), + None => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), + }, + _ => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), + } +} + +/// Hashed filter using HashSet. 
Best for larger lists (>8 elements). +fn instantiate_hashed_filter( + in_array: ArrayRef, +) -> Result> { + match in_array.data_type() { + DataType::Int8 => Ok(Arc::new(Int8HashedFilter::try_new(&in_array)?)), + DataType::Int16 => Ok(Arc::new(Int16HashedFilter::try_new(&in_array)?)), + DataType::Int32 => Ok(Arc::new(Int32HashedFilter::try_new(&in_array)?)), + DataType::Int64 => Ok(Arc::new(Int64HashedFilter::try_new(&in_array)?)), + DataType::UInt8 => Ok(Arc::new(UInt8HashedFilter::try_new(&in_array)?)), + DataType::UInt16 => Ok(Arc::new(UInt16HashedFilter::try_new(&in_array)?)), + DataType::UInt32 => Ok(Arc::new(UInt32HashedFilter::try_new(&in_array)?)), + DataType::UInt64 => Ok(Arc::new(UInt64HashedFilter::try_new(&in_array)?)), + // Floats don't implement Hash, fall through to generic + DataType::Utf8View => match Utf8ViewHashedFilter::try_new(&in_array) { + Some(Ok(filter)) => Ok(Arc::new(filter)), + Some(Err(e)) => Err(e), + None => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), + }, + _ => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), } } @@ -211,26 +248,184 @@ impl ArrayStaticFilter { } } -// Macro to generate specialized StaticFilter implementations for primitive types -macro_rules! primitive_static_filter { - ($Name:ident, $ArrowType:ty) => { +/// Threshold for switching from sorted Vec (binary search) to HashSet +/// For small lists, binary search has better cache locality and lower overhead +/// Maximum list size for using sorted lookup (binary search). +/// Lists with more elements use hash lookup instead. +const SORTED_LOOKUP_MAX_LEN: usize = 8; + +/// Helper to build a BooleanArray result for IN list operations. +/// Handles SQL three-valued logic for NULL values. 
+/// +/// # Arguments +/// * `len` - Number of elements in the needle array +/// * `needle_nulls` - Optional validity buffer from the needle array +/// * `haystack_has_nulls` - Whether the IN list contains NULL values +/// * `negated` - Whether this is a NOT IN operation +/// * `contains` - Closure that returns whether needle[i] is found in the haystack +#[inline] +fn build_in_list_result( + len: usize, + needle_nulls: Option<&NullBuffer>, + haystack_has_nulls: bool, + negated: bool, + contains: C, +) -> BooleanArray +where + C: Fn(usize) -> bool, +{ + // Use collect_bool for all paths - it's vectorized and faster than element-by-element append. + // Match on (needle_has_nulls, haystack_has_nulls, negated) to specialize each case. + match (needle_nulls, haystack_has_nulls, negated) { + // Haystack has nulls: result is NULL when not found (might match the NULL) + // values_buf == nulls_buf, so compute once and clone + (Some(validity), true, false) => { + let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + let buf = validity.inner() & &contains_buf; + BooleanArray::new(buf.clone(), Some(NullBuffer::new(buf))) + } + (None, true, false) => { + let buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + BooleanArray::new(buf.clone(), Some(NullBuffer::new(buf))) + } + (Some(validity), true, true) => { + // Compute nulls_buf via SIMD AND, then derive values_buf via XOR. + // Uses identity: A & !B = A ^ (A & B) to get values from nulls. 
+ let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + let nulls_buf = validity.inner() & &contains_buf; + let values_buf = validity.inner() ^ &nulls_buf; // valid & !contains + BooleanArray::new(values_buf, Some(NullBuffer::new(nulls_buf))) + } + (None, true, true) => { + // No needle nulls, but haystack has nulls + let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + let values_buf = !&contains_buf; + BooleanArray::new(values_buf, Some(NullBuffer::new(contains_buf))) + } + // Only needle has nulls: nulls_buf is just validity (reuse it directly!) + (Some(validity), false, false) => { + let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + let values_buf = validity.inner() & &contains_buf; + BooleanArray::new(values_buf, Some(validity.clone())) + } + (Some(validity), false, true) => { + let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + let values_buf = validity.inner() & &(!&contains_buf); + BooleanArray::new(values_buf, Some(validity.clone())) + } + // No nulls anywhere: no validity buffer needed + (None, false, false) => { + let buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + BooleanArray::new(buf, None) + } + (None, false, true) => { + let buf = BooleanBuffer::collect_bool(len, |i| !contains(i)); + BooleanArray::new(buf, None) + } + } +} + +/// Sorted lookup using binary search. Best for small lists (≤8 elements). +struct SortedLookup(Vec); + +impl SortedLookup { + fn new(mut values: Vec) -> Self { + values.sort_unstable(); + values.dedup(); + Self(values) + } + + #[inline] + fn contains(&self, value: &T) -> bool { + self.0.binary_search(value).is_ok() + } +} + +/// Sorted lookup for f32 using total_cmp (floats don't implement Ord due to NaN). 
+struct F32SortedLookup(Vec); + +impl F32SortedLookup { + fn new(mut values: Vec) -> Self { + values.sort_unstable_by(|a, b| a.total_cmp(b)); + values.dedup_by(|a, b| a.total_cmp(b).is_eq()); + Self(values) + } + + #[inline] + fn contains(&self, value: &f32) -> bool { + self.0 + .binary_search_by(|probe| probe.total_cmp(value)) + .is_ok() + } +} + +/// Sorted lookup for f64 using total_cmp (floats don't implement Ord due to NaN). +struct F64SortedLookup(Vec); + +impl F64SortedLookup { + fn new(mut values: Vec) -> Self { + values.sort_unstable_by(|a, b| a.total_cmp(b)); + values.dedup_by(|a, b| a.total_cmp(b).is_eq()); + Self(values) + } + + #[inline] + fn contains(&self, value: &f64) -> bool { + self.0 + .binary_search_by(|probe| probe.total_cmp(value)) + .is_ok() + } +} + +/// Hash-based lookup. Best for larger lists (>8 elements). +struct HashedLookup(HashSet); + +impl HashedLookup { + fn new(values: Vec) -> Self { + Self(values.into_iter().collect()) + } + + #[inline] + fn contains(&self, value: &T) -> bool { + self.0.contains(value) + } +} + +/// Helper macro for dictionary array handling in StaticFilter::contains +/// This pattern is the same across all filter implementations +macro_rules! handle_dictionary { + ($self:ident, $v:ident, $negated:ident) => { + downcast_dictionary_array! { + $v => { + let values_contains = $self.contains($v.values().as_ref(), $negated)?; + let result = take(&values_contains, $v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + }; +} + +// Base macro to generate sorted StaticFilter with explicit lookup type. +macro_rules! 
sorted_static_filter_impl { + ($Name:ident, $ArrowType:ty, $LookupType:ty) => { struct $Name { null_count: usize, - values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>, + values: $LookupType, } impl $Name { fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + let in_array = + in_array.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; - let mut values = HashSet::with_capacity(in_array.len()); let null_count = in_array.null_count(); - - for v in in_array.iter().flatten() { - values.insert(v); - } + let values = <$LookupType>::new(in_array.iter().flatten().collect()); Ok(Self { null_count, values }) } @@ -242,90 +437,227 @@ macro_rules! primitive_static_filter { } fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! 
{ - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } + handle_dictionary!(self, v, negated); + + let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; + + let values = v.values(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + // SAFETY: i < len is guaranteed by build_in_list_result + |i| self.values.contains(unsafe { values.get_unchecked(i) }), + )) + } + } + }; +} - let v = v - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let haystack_has_nulls = self.null_count > 0; - let has_nulls = v.null_count() > 0 || haystack_has_nulls; - - // SAFETY: collect_bool guarantees i < len for all closure calls - let result = match (has_nulls, negated) { - (true, false) => { - // Has nulls somewhere, not negated - let len = v.len(); - let values_buf = BooleanBuffer::collect_bool(len, |i| { - // SAFETY: i < len is guaranteed by collect_bool - // If found in set -> true, otherwise false (null handled by validity) - v.is_valid(i) && self.values.contains(unsafe { &v.value_unchecked(i) }) - }); - let nulls_buf = BooleanBuffer::collect_bool(len, |i| { - // SAFETY: i < len is guaranteed by collect_bool - // Valid (not null) if: needle is valid AND (found OR haystack has no nulls) - v.is_valid(i) && (self.values.contains(unsafe { &v.value_unchecked(i) }) || !haystack_has_nulls) - }); - BooleanArray::new(values_buf, Some(NullBuffer::new(nulls_buf))) - } - (true, true) => { - // Has nulls somewhere, negated - let len = v.len(); - let values_buf = BooleanBuffer::collect_bool(len, |i| { - // SAFETY: i < len is guaranteed by collect_bool - // If found in set -> false, otherwise true (null 
handled by validity) - v.is_valid(i) && !self.values.contains(unsafe { &v.value_unchecked(i) }) - }); - let nulls_buf = BooleanBuffer::collect_bool(len, |i| { - // SAFETY: i < len is guaranteed by collect_bool - // Valid (not null) if: needle is valid AND (found OR haystack has no nulls) - v.is_valid(i) && (self.values.contains(unsafe { &v.value_unchecked(i) }) || !haystack_has_nulls) - }); - BooleanArray::new(values_buf, Some(NullBuffer::new(nulls_buf))) - } - (false, false) => { - // No nulls anywhere, not negated - let values = v.values(); - let values_buf = BooleanBuffer::collect_bool(values.len(), |i| { - // SAFETY: i < len is guaranteed by collect_bool - self.values.contains(unsafe { values.get_unchecked(i) }) - }); - BooleanArray::new(values_buf, None) - } - (false, true) => { - // No nulls anywhere, negated - let values = v.values(); - let values_buf = BooleanBuffer::collect_bool(values.len(), |i| { - // SAFETY: i < len is guaranteed by collect_bool - !self.values.contains(unsafe { values.get_unchecked(i) }) - }); - BooleanArray::new(values_buf, None) - } - }; - Ok(result) +// Convenience macro for integer types (derives SortedLookup from ArrowType). +macro_rules! sorted_static_filter { + ($Name:ident, $ArrowType:ty) => { + sorted_static_filter_impl!( + $Name, + $ArrowType, + SortedLookup<<$ArrowType as ArrowPrimitiveType>::Native> + ); + }; +} + +// Macro to generate hashed StaticFilter for primitive types using HashedLookup. +macro_rules! 
hashed_static_filter { + ($Name:ident, $ArrowType:ty) => { + struct $Name { + null_count: usize, + values: HashedLookup<<$ArrowType as ArrowPrimitiveType>::Native>, + } + + impl $Name { + fn try_new(in_array: &ArrayRef) -> Result { + let in_array = + in_array.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; + + let null_count = in_array.null_count(); + let values = HashedLookup::new(in_array.iter().flatten().collect()); + + Ok(Self { null_count, values }) + } + } + + impl StaticFilter for $Name { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; + + let values = v.values(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + // SAFETY: i < len is guaranteed by build_in_list_result + |i| self.values.contains(unsafe { values.get_unchecked(i) }), + )) } } }; } -// Generate specialized filters for all integer primitive types -// Note: Float32 and Float64 are excluded because they don't implement Hash/Eq due to NaN -primitive_static_filter!(Int8StaticFilter, Int8Type); -primitive_static_filter!(Int16StaticFilter, Int16Type); -primitive_static_filter!(Int32StaticFilter, Int32Type); -primitive_static_filter!(Int64StaticFilter, Int64Type); -primitive_static_filter!(UInt8StaticFilter, UInt8Type); -primitive_static_filter!(UInt16StaticFilter, UInt16Type); -primitive_static_filter!(UInt32StaticFilter, UInt32Type); -primitive_static_filter!(UInt64StaticFilter, UInt64Type); +// Generate specialized filters for integer types (sorted for small lists, hashed for large). 
+sorted_static_filter!(Int8SortedFilter, Int8Type); +sorted_static_filter!(Int16SortedFilter, Int16Type); +sorted_static_filter!(Int32SortedFilter, Int32Type); +sorted_static_filter!(Int64SortedFilter, Int64Type); +sorted_static_filter!(UInt8SortedFilter, UInt8Type); +sorted_static_filter!(UInt16SortedFilter, UInt16Type); +sorted_static_filter!(UInt32SortedFilter, UInt32Type); +sorted_static_filter!(UInt64SortedFilter, UInt64Type); + +hashed_static_filter!(Int8HashedFilter, Int8Type); +hashed_static_filter!(Int16HashedFilter, Int16Type); +hashed_static_filter!(Int32HashedFilter, Int32Type); +hashed_static_filter!(Int64HashedFilter, Int64Type); +hashed_static_filter!(UInt8HashedFilter, UInt8Type); +hashed_static_filter!(UInt16HashedFilter, UInt16Type); +hashed_static_filter!(UInt32HashedFilter, UInt32Type); +hashed_static_filter!(UInt64HashedFilter, UInt64Type); + +// Float types: sorted only (floats don't implement Hash/Eq due to NaN). +sorted_static_filter_impl!(Float32SortedFilter, Float32Type, F32SortedLookup); +sorted_static_filter_impl!(Float64SortedFilter, Float64Type, F64SortedLookup); + +/// Maximum length for inline strings in Utf8View. +/// Strings ≤12 bytes are stored entirely inline in the u128 view. +const UTF8VIEW_INLINE_LEN: usize = 12; + +/// Extract string length from a StringView u128 representation +/// Layout: bytes 0-3 = length (u32 little-endian), bytes 4-15 = inline data +#[inline] +fn view_len(view: u128) -> usize { + (view as u32) as usize +} + +/// Returns (null_count, views) if all non-null strings are ≤12 bytes, otherwise None. 
+fn collect_short_string_views(in_array: &ArrayRef) -> Option<(usize, Vec)> { + let in_array = in_array.as_string_view_opt()?; + let raw_views = in_array.views(); + + // Check that all non-null strings are ≤12 bytes (inline) + for i in 0..in_array.len() { + if in_array.is_valid(i) && view_len(raw_views[i]) > UTF8VIEW_INLINE_LEN { + return None; // Has long strings, use generic filter + } + } + + let views: Vec = (0..in_array.len()) + .filter(|&i| in_array.is_valid(i)) + .map(|i| raw_views[i]) + .collect(); + + Some((in_array.null_count(), views)) +} + +/// Sorted filter for Utf8View when all values are short (≤12 bytes inline). +/// Uses binary search over sorted raw u128 views. Best for small lists. +struct Utf8ViewSortedFilter { + null_count: usize, + values: SortedLookup, +} + +impl Utf8ViewSortedFilter { + fn try_new(in_array: &ArrayRef) -> Option> { + let (null_count, views) = collect_short_string_views(in_array)?; + Some(Ok(Self { + null_count, + values: SortedLookup::new(views), + })) + } +} + +impl StaticFilter for Utf8ViewSortedFilter { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + let v = v.as_string_view_opt().ok_or_else(|| { + exec_datafusion_err!("Failed to downcast array to StringViewArray") + })?; + + let views = v.views(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + |i| self.values.contains(&views[i]), + )) + } +} + +/// Hashed filter for Utf8View when all values are short (≤12 bytes inline). +/// Uses hash lookup over u128 views. Best for large lists. 
+struct Utf8ViewHashedFilter { + null_count: usize, + values: HashedLookup, +} + +impl Utf8ViewHashedFilter { + fn try_new(in_array: &ArrayRef) -> Option> { + let (null_count, views) = collect_short_string_views(in_array)?; + Some(Ok(Self { + null_count, + values: HashedLookup::new(views), + })) + } +} + +impl StaticFilter for Utf8ViewHashedFilter { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + let v = v.as_string_view_opt().ok_or_else(|| { + exec_datafusion_err!("Failed to downcast array to StringViewArray") + })?; + + let views = v.views(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + |i| self.values.contains(&views[i]), + )) + } +} /// Evaluates the list of expressions into an array, flattening any dictionaries fn evaluate_list(