From d209b3c8d59cd06e6a9ad22ee448bec6a77e7556 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Tue, 6 Jan 2026 17:23:14 +0000 Subject: [PATCH 1/2] feat: imple densitysketch Signed-off-by: Chojan Shang --- datasketches/src/density/mod.rs | 41 ++ datasketches/src/density/serialization.rs | 22 + datasketches/src/density/sketch.rs | 551 ++++++++++++++++++++++ datasketches/src/lib.rs | 1 + datasketches/tests/density_sketch_test.rs | 253 ++++++++++ 5 files changed, 868 insertions(+) create mode 100644 datasketches/src/density/mod.rs create mode 100644 datasketches/src/density/serialization.rs create mode 100644 datasketches/src/density/sketch.rs create mode 100644 datasketches/tests/density_sketch_test.rs diff --git a/datasketches/src/density/mod.rs b/datasketches/src/density/mod.rs new file mode 100644 index 0000000..029fe44 --- /dev/null +++ b/datasketches/src/density/mod.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Density sketch implementation for density estimation from streaming data. +//! +//! The sketch maintains a coreset of points using a compaction scheme and +//! provides density estimates at query points via a kernel function. +//! +//! # Usage +//! +//! ```rust +//! # use datasketches::density::DensitySketch; +//! let mut sketch: DensitySketch = DensitySketch::new(10, 3); +//! sketch.update(vec![0.0, 0.0, 0.0]); +//! sketch.update(vec![1.0, 2.0, 3.0]); +//! let estimate = sketch.estimate(&[0.0, 0.0, 0.0]); +//! assert!(estimate > 0.0); +//! ``` + +mod serialization; +mod sketch; + +pub use self::sketch::DensityItem; +pub use self::sketch::DensityKernel; +pub use self::sketch::DensitySketch; +pub use self::sketch::DensityValue; +pub use self::sketch::GaussianKernel; diff --git a/datasketches/src/density/serialization.rs b/datasketches/src/density/serialization.rs new file mode 100644 index 0000000..22e41b6 --- /dev/null +++ b/datasketches/src/density/serialization.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub(super) const PREAMBLE_INTS_SHORT: u8 = 3; +pub(super) const PREAMBLE_INTS_LONG: u8 = 6; +pub(super) const SERIAL_VERSION: u8 = 1; +pub(super) const DENSITY_FAMILY_ID: u8 = 19; +pub(super) const FLAGS_IS_EMPTY: u8 = 1 << 2; diff --git a/datasketches/src/density/sketch.rs b/datasketches/src/density/sketch.rs new file mode 100644 index 0000000..273b795 --- /dev/null +++ b/datasketches/src/density/sketch.rs @@ -0,0 +1,551 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cell::Cell; +use std::io::Read; +use std::io::Write; +use std::time::SystemTime; +use std::time::UNIX_EPOCH; + +use crate::codec::SketchBytes; +use crate::codec::SketchSlice; +use crate::density::serialization::DENSITY_FAMILY_ID; +use crate::density::serialization::FLAGS_IS_EMPTY; +use crate::density::serialization::PREAMBLE_INTS_LONG; +use crate::density::serialization::PREAMBLE_INTS_SHORT; +use crate::density::serialization::SERIAL_VERSION; +use crate::error::Error; +use crate::error::ErrorKind; + +/// Floating point types supported by the density sketch. +pub trait DensityValue: Copy + PartialOrd + 'static { + /// Converts from f64. + fn from_f64(value: f64) -> Self; + /// Converts to f64 for accumulation. + fn to_f64(self) -> f64; +} + +impl DensityValue for f64 { + fn from_f64(value: f64) -> Self { + value + } + + fn to_f64(self) -> f64 { + self + } +} + +impl DensityValue for f32 { + fn from_f64(value: f64) -> Self { + value as f32 + } + + fn to_f64(self) -> f64 { + self as f64 + } +} + +/// Kernel used to compute density contributions between points. +pub trait DensityKernel { + /// Returns the kernel evaluation for the two points. + fn evaluate(&self, left: &[T], right: &[T]) -> T; +} + +/// Gaussian kernel based on squared Euclidean distance. +#[derive(Debug, Default, Clone, Copy)] +pub struct GaussianKernel; + +impl DensityKernel for GaussianKernel { + fn evaluate(&self, left: &[T], right: &[T]) -> T { + let mut sum = 0.0f64; + for (a, b) in left.iter().zip(right.iter()) { + let diff = a.to_f64() - b.to_f64(); + sum += diff * diff; + } + T::from_f64((-sum).exp()) + } +} + +/// Density sketch for streaming density estimation. +pub struct DensitySketch { + kernel: Box>, + k: u16, + dim: u32, + num_retained: u32, + n: u64, + levels: Vec>>, +} + +impl DensitySketch { + /// Creates a new sketch using the Gaussian kernel. + /// + /// # Panics + /// + /// Panics if `k` is less than 2. + pub fn new(k: u16, dim: u32) -> Self { + Self::with_kernel(k, dim, Box::new(GaussianKernel)) + } + + /// Creates a new sketch with a custom kernel. + /// + /// # Panics + /// + /// Panics if `k` is less than 2. + pub fn with_kernel(k: u16, dim: u32, kernel: Box>) -> Self { + check_k(k); + Self { + kernel, + k, + dim, + num_retained: 0, + n: 0, + levels: vec![Vec::new()], + } + } + + /// Deserializes a sketch using the Gaussian kernel. + pub fn deserialize(bytes: &[u8]) -> Result { + Self::deserialize_with_kernel(bytes, Box::new(GaussianKernel)) + } + + /// Deserializes a sketch using the provided kernel. + pub fn deserialize_with_kernel( + bytes: &[u8], + kernel: Box>, + ) -> Result { + fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { + move |_| Error::insufficient_data(tag) + } + + let mut cursor = SketchSlice::new(bytes); + let preamble_ints = cursor.read_u8().map_err(make_error("preamble_ints"))?; + let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; + let family_id = cursor.read_u8().map_err(make_error("family_id"))?; + let flags = cursor.read_u8().map_err(make_error("flags"))?; + let k = cursor.read_u16_le().map_err(make_error("k"))?; + cursor.read_u16_le().map_err(make_error("unused"))?; + let dim = cursor.read_u32_le().map_err(make_error("dim"))?; + + if family_id != DENSITY_FAMILY_ID { + return Err(Error::invalid_family( + DENSITY_FAMILY_ID, + family_id, + "DensitySketch", + )); + } + if serial_version != SERIAL_VERSION { + return Err(Error::unsupported_serial_version( + SERIAL_VERSION, + serial_version, + )); + } + validate_k(k)?; + check_header_validity(preamble_ints, flags)?; + + let is_empty = (flags & FLAGS_IS_EMPTY) != 0; + if is_empty { + return Ok(Self::with_kernel(k, dim, kernel)); + } + + let num_retained = cursor.read_u32_le().map_err(make_error("num_retained"))?; + let n = cursor.read_u64_le().map_err(make_error("n"))?; + + let mut levels = Vec::new(); + let mut remaining = num_retained as i64; + while remaining > 0 { + let level_size = cursor.read_u32_le().map_err(make_error("level_size"))?; + let mut level = Vec::with_capacity(level_size as usize); + for _ in 0..level_size { + let mut point = Vec::with_capacity(dim as usize); + for _ in 0..dim { + point.push(read_value(&mut cursor).map_err(make_error("point"))?); + } + level.push(point); + } + remaining -= level_size as i64; + levels.push(level); + } + if remaining != 0 { + return Err(Error::deserial( + "invalid number of retained points while decoding density sketch", + )); + } + + Ok(Self { + kernel, + k, + dim, + num_retained, + n, + levels, + }) + } + + /// Deserializes a sketch from a reader using the Gaussian kernel. + pub fn deserialize_from_reader(reader: &mut dyn Read) -> Result { + Self::deserialize_from_reader_with_kernel(reader, Box::new(GaussianKernel)) + } + + /// Deserializes a sketch from a reader using the provided kernel. + pub fn deserialize_from_reader_with_kernel( + reader: &mut dyn Read, + kernel: Box>, + ) -> Result { + let mut buf = Vec::new(); + reader + .read_to_end(&mut buf) + .map_err(|err| Error::deserial(format!("error reading stream: {err}")))?; + Self::deserialize_with_kernel(&buf, kernel) + } + + /// Returns the configured parameter k. + pub fn k(&self) -> u16 { + self.k + } + + /// Returns the configured dimension. + pub fn dim(&self) -> u32 { + self.dim + } + + /// Returns true if the sketch is empty. + pub fn is_empty(&self) -> bool { + self.num_retained == 0 + } + + /// Returns the number of points observed by this sketch. + pub fn n(&self) -> u64 { + self.n + } + + /// Returns the number of retained points. + pub fn num_retained(&self) -> u32 { + self.num_retained + } + + /// Returns true if the sketch is in estimation mode. + pub fn is_estimation_mode(&self) -> bool { + self.levels.len() > 1 + } + + /// Updates this sketch with a given point. + /// + /// # Panics + /// + /// Panics if the point dimension does not match this sketch. + pub fn update(&mut self, point: Vec) { + if point.len() != self.dim as usize { + panic!("dimension mismatch"); + } + while self.num_retained >= self.k as u32 * self.levels.len() as u32 { + self.compact(); + } + self.levels[0].push(point); + self.num_retained += 1; + self.n += 1; + } + + /// Updates this sketch with a slice, copying the point into the sketch. + /// + /// # Panics + /// + /// Panics if the point dimension does not match this sketch. + pub fn update_slice(&mut self, point: &[T]) { + self.update(point.to_vec()); + } + + /// Merges another sketch into this one. + /// + /// # Panics + /// + /// Panics if dimensions do not match. + pub fn merge(&mut self, other: &Self) { + if other.is_empty() { + return; + } + if other.dim != self.dim { + panic!("dimension mismatch"); + } + while self.levels.len() < other.levels.len() { + self.levels.push(Vec::new()); + } + for (height, level) in other.levels.iter().enumerate() { + self.levels[height].extend(level.iter().cloned()); + } + self.num_retained += other.num_retained; + self.n += other.n; + while self.num_retained >= self.k as u32 * self.levels.len() as u32 { + self.compact(); + } + } + + /// Returns a density estimate at a given point. + /// + /// # Panics + /// + /// Panics if the sketch is empty. + pub fn estimate(&self, point: &[T]) -> T { + if self.is_empty() { + panic!("operation is undefined for an empty sketch"); + } + let n = self.n as f64; + let mut density = 0.0f64; + for (height, level) in self.levels.iter().enumerate() { + let weight = match height { + 0..=127 => 1u128 << height, + _ => panic!("level height too large"), + }; + let height_weight = weight as f64; + for p in level { + density += height_weight * self.kernel.evaluate(p, point).to_f64() / n; + } + } + T::from_f64(density) + } + + /// Serializes the sketch to a byte vector. + pub fn serialize(&self) -> Vec { + let preamble_ints = if self.is_empty() { + PREAMBLE_INTS_SHORT + } else { + PREAMBLE_INTS_LONG + }; + let mut size_bytes = preamble_ints as usize * 4; + if !self.is_empty() { + for level in &self.levels { + size_bytes += 4 + (level.len() * self.dim as usize * std::mem::size_of::()); + } + } + let mut bytes = SketchBytes::with_capacity(size_bytes); + bytes.write_u8(preamble_ints); + bytes.write_u8(SERIAL_VERSION); + bytes.write_u8(DENSITY_FAMILY_ID); + let flags = if self.is_empty() { FLAGS_IS_EMPTY } else { 0 }; + bytes.write_u8(flags); + bytes.write_u16_le(self.k); + bytes.write_u16_le(0); + bytes.write_u32_le(self.dim); + + if self.is_empty() { + return bytes.into_bytes(); + } + + bytes.write_u32_le(self.num_retained); + bytes.write_u64_le(self.n); + for level in &self.levels { + bytes.write_u32_le(level.len() as u32); + for point in level { + for value in point { + write_value(&mut bytes, *value); + } + } + } + bytes.into_bytes() + } + + /// Serializes the sketch to a writer. + pub fn serialize_to_writer(&self, writer: &mut dyn Write) -> std::io::Result<()> { + writer.write_all(&self.serialize()) + } + + /// Returns an iterator over retained points with their weights. + pub fn iter(&self) -> DensityIter<'_, T> { + DensityIter { + levels: &self.levels, + level_index: 0, + item_index: 0, + } + } + + fn compact(&mut self) { + for height in 0..self.levels.len() { + if self.levels[height].len() >= self.k as usize { + if height + 1 >= self.levels.len() { + self.levels.push(Vec::new()); + } + self.compact_level(height); + break; + } + } + } + + fn compact_level(&mut self, height: usize) { + let level_len = self.levels[height].len(); + if level_len == 0 { + return; + } + shuffle(&mut self.levels[height]); + let mut bits = vec![false; level_len]; + bits[0] = random_bit(); + for i in 1..level_len { + let mut delta = 0.0f64; + for (j, bit) in bits.iter().enumerate().take(i) { + let weight = if *bit { 1.0 } else { -1.0 }; + delta += weight + * self + .kernel + .evaluate(&self.levels[height][i], &self.levels[height][j]) + .to_f64(); + } + bits[i] = delta < 0.0; + } + let old_level = std::mem::take(&mut self.levels[height]); + for (index, point) in old_level.into_iter().enumerate() { + if bits[index] { + self.levels[height + 1].push(point); + } else { + self.num_retained -= 1; + } + } + } +} + +/// Borrowed view of a retained point and its weight. +pub struct DensityItem<'a, T> { + /// The retained point. + pub point: &'a [T], + /// The weight associated with the point. + pub weight: u64, +} + +/// Iterator over retained points and their weights. +pub struct DensityIter<'a, T> { + levels: &'a [Vec>], + level_index: usize, + item_index: usize, +} + +impl<'a, T> Iterator for DensityIter<'a, T> { + type Item = DensityItem<'a, T>; + + fn next(&mut self) -> Option { + while self.level_index < self.levels.len() { + let level = &self.levels[self.level_index]; + if self.item_index < level.len() { + let weight = match self.level_index { + 0..=63 => 1u64 << self.level_index, + _ => panic!("level height too large"), + }; + let item = DensityItem { + point: &level[self.item_index], + weight, + }; + self.item_index += 1; + return Some(item); + } + self.level_index += 1; + self.item_index = 0; + } + None + } +} + +impl<'a, T: DensityValue> IntoIterator for &'a DensitySketch { + type Item = DensityItem<'a, T>; + type IntoIter = DensityIter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +fn check_k(k: u16) { + assert!(k >= 2, "k must be > 1. Found: {k}"); +} + +fn validate_k(k: u16) -> Result<(), Error> { + if k >= 2 { + Ok(()) + } else { + Err(Error::new( + ErrorKind::InvalidArgument, + format!("k must be > 1. Found: {k}"), + )) + } +} + +fn check_header_validity(preamble_ints: u8, flags: u8) -> Result<(), Error> { + let empty = (flags & FLAGS_IS_EMPTY) != 0; + if (empty && preamble_ints == PREAMBLE_INTS_SHORT) + || (!empty && preamble_ints == PREAMBLE_INTS_LONG) + { + return Ok(()); + } + let expected = if empty { + PREAMBLE_INTS_SHORT + } else { + PREAMBLE_INTS_LONG + }; + Err(Error::invalid_preamble_longs(expected, preamble_ints)) +} + +fn write_value(bytes: &mut SketchBytes, value: T) { + if std::mem::size_of::() == 4 { + bytes.write_f32_le(value.to_f64() as f32); + } else { + bytes.write_f64_le(value.to_f64()); + } +} + +fn read_value(cursor: &mut SketchSlice<'_>) -> std::io::Result { + if std::mem::size_of::() == 4 { + cursor.read_f32_le().map(|v| T::from_f64(v as f64)) + } else { + cursor.read_f64_le().map(T::from_f64) + } +} + +thread_local! { + static RNG_STATE: Cell = Cell::new(seed_rng()); +} + +fn seed_rng() -> u64 { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + let mut seed = nanos as u64 ^ (std::process::id() as u64); + if seed == 0 { + seed = 0x9e3779b97f4a7c15; + } + seed +} + +fn next_u64() -> u64 { + RNG_STATE.with(|state| { + let mut x = state.get(); + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + state.set(x); + x + }) +} + +fn random_bit() -> bool { + (next_u64() & 1) != 0 +} + +fn shuffle(slice: &mut [T]) { + if slice.len() <= 1 { + return; + } + for i in (1..slice.len()).rev() { + let j = (next_u64() % (i as u64 + 1)) as usize; + slice.swap(i, j); + } +} diff --git a/datasketches/src/lib.rs b/datasketches/src/lib.rs index 473ab9b..5f1ea6e 100644 --- a/datasketches/src/lib.rs +++ b/datasketches/src/lib.rs @@ -32,6 +32,7 @@ compile_error!("datasketches does not support big-endian targets"); pub mod bloom; pub mod countmin; +pub mod density; pub mod error; pub mod frequencies; pub mod hll; diff --git a/datasketches/tests/density_sketch_test.rs b/datasketches/tests/density_sketch_test.rs new file mode 100644 index 0000000..53586cc --- /dev/null +++ b/datasketches/tests/density_sketch_test.rs @@ -0,0 +1,253 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::io::Cursor; + +use datasketches::density::DensityKernel; +use datasketches::density::DensitySketch; + +#[test] +#[should_panic(expected = "operation is undefined for an empty sketch")] +fn test_empty() { + let sketch: DensitySketch = DensitySketch::new(10, 3); + assert!(sketch.is_empty()); + let _ = sketch.estimate(&[0.0, 0.0, 0.0]); +} + +#[test] +#[should_panic(expected = "dimension mismatch")] +fn test_dimension_mismatch() { + let mut sketch: DensitySketch = DensitySketch::new(10, 3); + sketch.update(vec![0.0, 0.0]); +} + +#[test] +fn test_one_item() { + let mut sketch: DensitySketch = DensitySketch::new(10, 3); + + sketch.update(vec![0.0, 0.0, 0.0]); + assert!(!sketch.is_empty()); + assert!(!sketch.is_estimation_mode()); + assert_eq!(sketch.estimate(&[0.0, 0.0, 0.0]), 1.0); + assert!(sketch.estimate(&[0.01, 0.01, 0.01]) > 0.95); + assert!(sketch.estimate(&[1.0, 1.0, 1.0]) < 0.05); +} + +#[test] +fn test_merge() { + let mut sketch1: DensitySketch = DensitySketch::new(10, 4); + sketch1.update(vec![0.0, 0.0, 0.0, 0.0]); + sketch1.update(vec![1.0, 2.0, 3.0, 4.0]); + + let mut sketch2: DensitySketch = DensitySketch::new(10, 4); + sketch2.update(vec![5.0, 6.0, 7.0, 8.0]); + + sketch1.merge(&sketch2); + assert_eq!(sketch1.n(), 3); + assert_eq!(sketch1.num_retained(), 3); +} + +#[test] +fn test_iterator() { + let mut sketch: DensitySketch = DensitySketch::new(10, 3); + let n = 1000; + for i in 1..=n { + sketch.update(vec![i as f32, i as f32, i as f32]); + } + assert_eq!(sketch.n(), n as u64); + assert!(sketch.is_estimation_mode()); + + let mut count = 0; + for item in &sketch { + count += 1; + assert_eq!(item.point.len(), sketch.dim() as usize); + } + assert_eq!(count as u32, sketch.num_retained()); +} + +#[derive(Clone, Copy)] +struct SphericalKernel { + radius_squared: f32, +} + +impl DensityKernel for SphericalKernel { + fn evaluate(&self, left: &[f32], right: &[f32]) -> f32 { + let mut sum = 0.0f32; + for (a, b) in left.iter().zip(right.iter()) { + let diff = a - b; + sum += diff * diff; + } + if sum <= self.radius_squared { 1.0 } else { 0.0 } + } +} + +#[test] +fn test_custom_kernel() { + let kernel = SphericalKernel { + radius_squared: 0.25, + }; + let mut sketch: DensitySketch = DensitySketch::with_kernel(10, 3, Box::new(kernel)); + + sketch.update(vec![1.0, 1.0, 1.0]); + assert_eq!(sketch.estimate(&[1.001, 1.001, 1.001]), 1.0); + assert_eq!(sketch.estimate(&[2.0, 2.0, 2.0]), 0.0); + + let n = 1000; + for i in 2..=n { + sketch.update(vec![i as f32, i as f32, i as f32]); + } + assert_eq!(sketch.n(), n as u64); + assert!(sketch.is_estimation_mode()); + let mut count = 0; + for item in &sketch { + count += 1; + assert_eq!(item.point.len(), sketch.dim() as usize); + } + assert_eq!(count as u32, sketch.num_retained()); +} + +#[test] +fn test_serialize_empty() { + let sketch: DensitySketch = DensitySketch::new(10, 2); + let bytes = sketch.serialize(); + let decoded = DensitySketch::::deserialize(&bytes).unwrap(); + assert!(decoded.is_empty()); + assert!(!decoded.is_estimation_mode()); + assert_eq!(sketch.k(), decoded.k()); + assert_eq!(sketch.dim(), decoded.dim()); + assert_eq!(sketch.n(), decoded.n()); + assert_eq!(sketch.num_retained(), decoded.num_retained()); + + let mut cursor = Cursor::new(Vec::new()); + sketch.serialize_to_writer(&mut cursor).unwrap(); + cursor.set_position(0); + let decoded = DensitySketch::::deserialize_from_reader(&mut cursor).unwrap(); + assert!(decoded.is_empty()); + assert!(!decoded.is_estimation_mode()); + assert_eq!(sketch.k(), decoded.k()); + assert_eq!(sketch.dim(), decoded.dim()); + assert_eq!(sketch.n(), decoded.n()); + assert_eq!(sketch.num_retained(), decoded.num_retained()); +} + +#[test] +fn test_serialize_bytes() { + let k = 10; + let dim = 3; + let mut sketch: DensitySketch = DensitySketch::new(k, dim); + + for i in 0..k { + let value = i as f64; + sketch.update(vec![value, value.sqrt(), -value]); + } + assert!(!sketch.is_estimation_mode()); + + let bytes = sketch.serialize(); + let decoded = DensitySketch::::deserialize(&bytes).unwrap(); + assert!(!decoded.is_empty()); + assert!(!decoded.is_estimation_mode()); + assert_eq!(sketch.k(), decoded.k()); + assert_eq!(sketch.dim(), decoded.dim()); + assert_eq!(sketch.n(), decoded.n()); + assert_eq!(sketch.num_retained(), decoded.num_retained()); + let mut iter_left = sketch.iter(); + let mut iter_right = decoded.iter(); + while let (Some(left), Some(right)) = (iter_left.next(), iter_right.next()) { + assert_eq!(left.point[0], right.point[0]); + assert_eq!(left.weight, right.weight); + } + + let n = 1031; + for i in k..n { + let value = i as f64; + sketch.update(vec![value, value.sqrt(), -value]); + } + assert!(sketch.is_estimation_mode()); + + let bytes = sketch.serialize(); + let decoded = DensitySketch::::deserialize(&bytes).unwrap(); + assert!(!decoded.is_empty()); + assert!(decoded.is_estimation_mode()); + assert_eq!(sketch.k(), decoded.k()); + assert_eq!(sketch.dim(), decoded.dim()); + assert_eq!(sketch.n(), decoded.n()); + assert_eq!(sketch.num_retained(), decoded.num_retained()); + let mut iter_left = sketch.iter(); + let mut iter_right = decoded.iter(); + while let (Some(left), Some(right)) = (iter_left.next(), iter_right.next()) { + assert_eq!(left.point[0], right.point[0]); + assert_eq!(left.weight, right.weight); + } +} + +#[test] +fn test_serialize_stream() { + let k = 10; + let dim = 3; + let mut sketch: DensitySketch = DensitySketch::new(k, dim); + + for i in 0..k { + let value = i as f32; + sketch.update(vec![value, value.sin(), value.cos()]); + } + assert!(!sketch.is_estimation_mode()); + + let mut cursor = Cursor::new(Vec::new()); + sketch.serialize_to_writer(&mut cursor).unwrap(); + cursor.set_position(0); + let decoded = DensitySketch::::deserialize_from_reader(&mut cursor).unwrap(); + assert!(!decoded.is_empty()); + assert!(!decoded.is_estimation_mode()); + assert_eq!(sketch.k(), decoded.k()); + assert_eq!(sketch.dim(), decoded.dim()); + assert_eq!(sketch.n(), decoded.n()); + assert_eq!(sketch.num_retained(), decoded.num_retained()); + let mut iter_left = sketch.iter(); + let mut iter_right = decoded.iter(); + while let (Some(left), Some(right)) = (iter_left.next(), iter_right.next()) { + assert_eq!(left.point[0], right.point[0]); + assert_eq!(left.weight, right.weight); + assert_eq!(left.point[1], right.point[1]); + assert_eq!(left.point[2], right.point[2]); + } + + let n = 1031; + for i in k..n { + let value = i as f32; + sketch.update(vec![value, value.sqrt(), -value]); + } + assert!(sketch.is_estimation_mode()); + + let mut cursor = Cursor::new(Vec::new()); + sketch.serialize_to_writer(&mut cursor).unwrap(); + cursor.set_position(0); + let decoded = DensitySketch::::deserialize_from_reader(&mut cursor).unwrap(); + assert!(!decoded.is_empty()); + assert!(decoded.is_estimation_mode()); + assert_eq!(sketch.k(), decoded.k()); + assert_eq!(sketch.dim(), decoded.dim()); + assert_eq!(sketch.n(), decoded.n()); + assert_eq!(sketch.num_retained(), decoded.num_retained()); + let mut iter_left = sketch.iter(); + let mut iter_right = decoded.iter(); + while let (Some(left), Some(right)) = (iter_left.next(), iter_right.next()) { + assert_eq!(left.point[0], right.point[0]); + assert_eq!(left.weight, right.weight); + assert_eq!(left.point[1], right.point[1]); + assert_eq!(left.point[2], right.point[2]); + } +} From 10c8526d69dacf25a61a249f76c2c9911d18285b Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 18 Jan 2026 19:17:43 +0000 Subject: [PATCH 2/2] refactor: align density sketch kernel rng --- datasketches/src/common/mod.rs | 23 + datasketches/src/common/random.rs | 71 +++ datasketches/src/density/sketch.rs | 517 ++++++++++++---------- datasketches/src/lib.rs | 1 + datasketches/tests/density_sketch_test.rs | 57 +-- 5 files changed, 401 insertions(+), 268 deletions(-) create mode 100644 datasketches/src/common/mod.rs create mode 100644 datasketches/src/common/random.rs diff --git a/datasketches/src/common/mod.rs b/datasketches/src/common/mod.rs new file mode 100644 index 0000000..35d0df1 --- /dev/null +++ b/datasketches/src/common/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shared utilities across sketches. + +pub mod random; + +pub use self::random::RandomSource; +pub use self::random::XorShift64; diff --git a/datasketches/src/common/random.rs b/datasketches/src/common/random.rs new file mode 100644 index 0000000..b0630c8 --- /dev/null +++ b/datasketches/src/common/random.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shared random utilities for sketches. + +use std::time::SystemTime; +use std::time::UNIX_EPOCH; + +/// Random number source for sketches. +pub trait RandomSource { + /// Returns the next random 64-bit value. + fn next_u64(&mut self) -> u64; + + /// Returns a random boolean value. + fn next_bool(&mut self) -> bool { + (self.next_u64() & 1) != 0 + } +} + +/// Xorshift-based random generator for sketch operations. +#[derive(Debug, Clone, Copy)] +pub struct XorShift64 { + state: u64, +} + +impl XorShift64 { + /// Creates a new generator using the provided seed. + pub fn seeded(seed: u64) -> Self { + let state = if seed == 0 { 0x9e3779b97f4a7c15 } else { seed }; + Self { state } + } +} + +impl Default for XorShift64 { + fn default() -> Self { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + let mut seed = nanos as u64 ^ (std::process::id() as u64); + if seed == 0 { + seed = 0x9e3779b97f4a7c15; + } + Self::seeded(seed) + } +} + +impl RandomSource for XorShift64 { + fn next_u64(&mut self) -> u64 { + let mut x = self.state; + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + self.state = x; + x + } +} diff --git a/datasketches/src/density/sketch.rs b/datasketches/src/density/sketch.rs index 273b795..8625d8b 100644 --- a/datasketches/src/density/sketch.rs +++ b/datasketches/src/density/sketch.rs @@ -15,14 +15,12 @@ // specific language governing permissions and limitations // under the License. -use std::cell::Cell; -use std::io::Read; use std::io::Write; -use std::time::SystemTime; -use std::time::UNIX_EPOCH; use crate::codec::SketchBytes; use crate::codec::SketchSlice; +use crate::common::RandomSource; +use crate::common::XorShift64; use crate::density::serialization::DENSITY_FAMILY_ID; use crate::density::serialization::FLAGS_IS_EMPTY; use crate::density::serialization::PREAMBLE_INTS_LONG; @@ -31,6 +29,9 @@ use crate::density::serialization::SERIAL_VERSION; use crate::error::Error; use crate::error::ErrorKind; +type SerializeValue = fn(&mut SketchBytes, T); +type DeserializeValue = fn(&mut SketchSlice<'_>) -> std::io::Result; + /// Floating point types supported by the density sketch. pub trait DensityValue: Copy + PartialOrd + 'static { /// Converts from f64. @@ -60,17 +61,17 @@ impl DensityValue for f32 { } /// Kernel used to compute density contributions between points. -pub trait DensityKernel { +pub trait DensityKernel { /// Returns the kernel evaluation for the two points. - fn evaluate(&self, left: &[T], right: &[T]) -> T; + fn evaluate(&self, left: &[T], right: &[T]) -> T; } /// Gaussian kernel based on squared Euclidean distance. #[derive(Debug, Default, Clone, Copy)] pub struct GaussianKernel; -impl DensityKernel for GaussianKernel { - fn evaluate(&self, left: &[T], right: &[T]) -> T { +impl DensityKernel for GaussianKernel { + fn evaluate(&self, left: &[T], right: &[T]) -> T { let mut sum = 0.0f64; for (a, b) in left.iter().zip(right.iter()) { let diff = a.to_f64() - b.to_f64(); @@ -81,8 +82,13 @@ impl DensityKernel for GaussianKernel { } /// Density sketch for streaming density estimation. -pub struct DensitySketch { - kernel: Box>, +pub struct DensitySketch< + T: DensityValue, + K: DensityKernel = GaussianKernel, + R: RandomSource = XorShift64, +> { + kernel: K, + rng: R, k: u16, dim: u32, num_retained: u32, @@ -90,126 +96,107 @@ pub struct DensitySketch { levels: Vec>>, } -impl DensitySketch { +impl DensitySketch { /// Creates a new sketch using the Gaussian kernel. /// /// # Panics /// /// Panics if `k` is less than 2. pub fn new(k: u16, dim: u32) -> Self { - Self::with_kernel(k, dim, Box::new(GaussianKernel)) + Self::with_kernel(k, dim, GaussianKernel) } +} - /// Creates a new sketch with a custom kernel. - /// - /// # Panics - /// - /// Panics if `k` is less than 2. - pub fn with_kernel(k: u16, dim: u32, kernel: Box>) -> Self { - check_k(k); - Self { - kernel, - k, - dim, - num_retained: 0, - n: 0, - levels: vec![Vec::new()], - } +impl DensitySketch { + /// Deserializes a sketch using the Gaussian kernel. + pub fn deserialize(bytes: &[u8]) -> Result { + Self::deserialize_with_kernel(bytes, GaussianKernel) } +} +impl DensitySketch { /// Deserializes a sketch using the Gaussian kernel. pub fn deserialize(bytes: &[u8]) -> Result { - Self::deserialize_with_kernel(bytes, Box::new(GaussianKernel)) + Self::deserialize_with_kernel(bytes, GaussianKernel) } +} +impl DensitySketch { /// Deserializes a sketch using the provided kernel. - pub fn deserialize_with_kernel( - bytes: &[u8], - kernel: Box>, - ) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } + pub fn deserialize_with_kernel(bytes: &[u8], kernel: K) -> Result { + Self::deserialize_with_kernel_and_rng(bytes, kernel, XorShift64::default()) + } +} - let mut cursor = SketchSlice::new(bytes); - let preamble_ints = cursor.read_u8().map_err(make_error("preamble_ints"))?; - let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; - let family_id = cursor.read_u8().map_err(make_error("family_id"))?; - let flags = cursor.read_u8().map_err(make_error("flags"))?; - let k = cursor.read_u16_le().map_err(make_error("k"))?; - cursor.read_u16_le().map_err(make_error("unused"))?; - let dim = cursor.read_u32_le().map_err(make_error("dim"))?; - - if family_id != DENSITY_FAMILY_ID { - return Err(Error::invalid_family( - DENSITY_FAMILY_ID, - family_id, - "DensitySketch", - )); - } - if serial_version != SERIAL_VERSION { - return Err(Error::unsupported_serial_version( - SERIAL_VERSION, - serial_version, - )); - } - validate_k(k)?; - check_header_validity(preamble_ints, flags)?; +impl DensitySketch { + /// Deserializes a sketch using the provided kernel. + pub fn deserialize_with_kernel(bytes: &[u8], kernel: K) -> Result { + Self::deserialize_with_kernel_and_rng(bytes, kernel, XorShift64::default()) + } +} - let is_empty = (flags & FLAGS_IS_EMPTY) != 0; - if is_empty { - return Ok(Self::with_kernel(k, dim, kernel)); - } +impl DensitySketch { + /// Deserializes a sketch using the provided kernel and random source. + pub fn deserialize_with_kernel_and_rng(bytes: &[u8], kernel: K, rng: R) -> Result { + deserialize_inner(bytes, kernel, rng, read_f32_value) + } - let num_retained = cursor.read_u32_le().map_err(make_error("num_retained"))?; - let n = cursor.read_u64_le().map_err(make_error("n"))?; - - let mut levels = Vec::new(); - let mut remaining = num_retained as i64; - while remaining > 0 { - let level_size = cursor.read_u32_le().map_err(make_error("level_size"))?; - let mut level = Vec::with_capacity(level_size as usize); - for _ in 0..level_size { - let mut point = Vec::with_capacity(dim as usize); - for _ in 0..dim { - point.push(read_value(&mut cursor).map_err(make_error("point"))?); - } - level.push(point); - } - remaining -= level_size as i64; - levels.push(level); - } - if remaining != 0 { - return Err(Error::deserial( - "invalid number of retained points while decoding density sketch", - )); - } + /// Serializes the sketch to a byte vector. + pub fn serialize(&self) -> Vec { + serialize_inner(self, 4, write_f32_value) + } - Ok(Self { - kernel, - k, - dim, - num_retained, - n, - levels, - }) + /// Serializes the sketch to a writer. + pub fn serialize_to_writer(&self, writer: &mut dyn Write) -> std::io::Result<()> { + writer.write_all(&self.serialize()) + } +} + +impl DensitySketch { + /// Deserializes a sketch using the provided kernel and random source. + pub fn deserialize_with_kernel_and_rng(bytes: &[u8], kernel: K, rng: R) -> Result { + deserialize_inner(bytes, kernel, rng, read_f64_value) } - /// Deserializes a sketch from a reader using the Gaussian kernel. - pub fn deserialize_from_reader(reader: &mut dyn Read) -> Result { - Self::deserialize_from_reader_with_kernel(reader, Box::new(GaussianKernel)) + /// Serializes the sketch to a byte vector. + pub fn serialize(&self) -> Vec { + serialize_inner(self, 8, write_f64_value) } - /// Deserializes a sketch from a reader using the provided kernel. - pub fn deserialize_from_reader_with_kernel( - reader: &mut dyn Read, - kernel: Box>, - ) -> Result { - let mut buf = Vec::new(); - reader - .read_to_end(&mut buf) - .map_err(|err| Error::deserial(format!("error reading stream: {err}")))?; - Self::deserialize_with_kernel(&buf, kernel) + /// Serializes the sketch to a writer. + pub fn serialize_to_writer(&self, writer: &mut dyn Write) -> std::io::Result<()> { + writer.write_all(&self.serialize()) + } +} + +impl DensitySketch { + /// Creates a new sketch with a custom kernel. + /// + /// # Panics + /// + /// Panics if `k` is less than 2. + pub fn with_kernel(k: u16, dim: u32, kernel: K) -> Self { + Self::with_kernel_and_rng(k, dim, kernel, XorShift64::default()) + } +} + +impl DensitySketch { + /// Creates a new sketch with a custom kernel and random source. + /// + /// # Panics + /// + /// Panics if `k` is less than 2. + pub fn with_kernel_and_rng(k: u16, dim: u32, kernel: K, rng: R) -> Self { + assert!(k >= 2, "k must be > 1. Found: {k}"); + Self { + kernel, + rng, + k, + dim, + num_retained: 0, + n: 0, + levels: vec![Vec::new()], + } } /// Returns the configured parameter k. @@ -305,11 +292,7 @@ impl DensitySketch { let n = self.n as f64; let mut density = 0.0f64; for (height, level) in self.levels.iter().enumerate() { - let weight = match height { - 0..=127 => 1u128 << height, - _ => panic!("level height too large"), - }; - let height_weight = weight as f64; + let height_weight = weight_for_level(height) as f64; for p in level { density += height_weight * self.kernel.evaluate(p, point).to_f64() / n; } @@ -317,51 +300,6 @@ impl DensitySketch { T::from_f64(density) } - /// Serializes the sketch to a byte vector. - pub fn serialize(&self) -> Vec { - let preamble_ints = if self.is_empty() { - PREAMBLE_INTS_SHORT - } else { - PREAMBLE_INTS_LONG - }; - let mut size_bytes = preamble_ints as usize * 4; - if !self.is_empty() { - for level in &self.levels { - size_bytes += 4 + (level.len() * self.dim as usize * std::mem::size_of::()); - } - } - let mut bytes = SketchBytes::with_capacity(size_bytes); - bytes.write_u8(preamble_ints); - bytes.write_u8(SERIAL_VERSION); - bytes.write_u8(DENSITY_FAMILY_ID); - let flags = if self.is_empty() { FLAGS_IS_EMPTY } else { 0 }; - bytes.write_u8(flags); - bytes.write_u16_le(self.k); - bytes.write_u16_le(0); - bytes.write_u32_le(self.dim); - - if self.is_empty() { - return bytes.into_bytes(); - } - - bytes.write_u32_le(self.num_retained); - bytes.write_u64_le(self.n); - for level in &self.levels { - bytes.write_u32_le(level.len() as u32); - for point in level { - for value in point { - write_value(&mut bytes, *value); - } - } - } - bytes.into_bytes() - } - - /// Serializes the sketch to a writer. - pub fn serialize_to_writer(&self, writer: &mut dyn Write) -> std::io::Result<()> { - writer.write_all(&self.serialize()) - } - /// Returns an iterator over retained points with their weights. pub fn iter(&self) -> DensityIter<'_, T> { DensityIter { @@ -388,20 +326,21 @@ impl DensitySketch { if level_len == 0 { return; } - shuffle(&mut self.levels[height]); let mut bits = vec![false; level_len]; - bits[0] = random_bit(); - for i in 1..level_len { - let mut delta = 0.0f64; - for (j, bit) in bits.iter().enumerate().take(i) { - let weight = if *bit { 1.0 } else { -1.0 }; - delta += weight - * self - .kernel - .evaluate(&self.levels[height][i], &self.levels[height][j]) - .to_f64(); + { + let rng = &mut self.rng; + let level = &mut self.levels[height]; + let kernel = &self.kernel; + shuffle_with_rng(rng, level); + bits[0] = random_bit(rng); + for i in 1..level_len { + let mut delta = 0.0f64; + for (j, bit) in bits.iter().enumerate().take(i) { + let weight = if *bit { 1.0 } else { -1.0 }; + delta += weight * kernel.evaluate(&level[i], &level[j]).to_f64(); + } + bits[i] = delta < 0.0; } - bits[i] = delta < 0.0; } let old_level = std::mem::take(&mut self.levels[height]); for (index, point) in old_level.into_iter().enumerate() { @@ -417,9 +356,21 @@ impl DensitySketch { /// Borrowed view of a retained point and its weight. pub struct DensityItem<'a, T> { /// The retained point. - pub point: &'a [T], + point: &'a [T], /// The weight associated with the point. - pub weight: u64, + weight: u64, +} + +impl<'a, T> DensityItem<'a, T> { + /// Returns the retained point. + pub fn point(&self) -> &'a [T] { + self.point + } + + /// Returns the weight associated with the point. + pub fn weight(&self) -> u64 { + self.weight + } } /// Iterator over retained points and their weights. @@ -436,10 +387,7 @@ impl<'a, T> Iterator for DensityIter<'a, T> { while self.level_index < self.levels.len() { let level = &self.levels[self.level_index]; if self.item_index < level.len() { - let weight = match self.level_index { - 0..=63 => 1u64 << self.level_index, - _ => panic!("level height too large"), - }; + let weight = weight_for_level(self.level_index); let item = DensityItem { point: &level[self.item_index], weight, @@ -454,7 +402,9 @@ impl<'a, T> Iterator for DensityIter<'a, T> { } } -impl<'a, T: DensityValue> IntoIterator for &'a DensitySketch { +impl<'a, T: DensityValue, K: DensityKernel, R: RandomSource> IntoIterator + for &'a DensitySketch +{ type Item = DensityItem<'a, T>; type IntoIter = DensityIter<'a, T>; @@ -463,89 +413,172 @@ impl<'a, T: DensityValue> IntoIterator for &'a DensitySketch { } } -fn check_k(k: u16) { - assert!(k >= 2, "k must be > 1. Found: {k}"); -} - -fn validate_k(k: u16) -> Result<(), Error> { - if k >= 2 { - Ok(()) - } else { - Err(Error::new( - ErrorKind::InvalidArgument, - format!("k must be > 1. Found: {k}"), - )) +fn weight_for_level(level: usize) -> u64 { + match level { + 0..=63 => 1u64 << level, + _ => panic!("level height too large"), } } -fn check_header_validity(preamble_ints: u8, flags: u8) -> Result<(), Error> { - let empty = (flags & FLAGS_IS_EMPTY) != 0; - if (empty && preamble_ints == PREAMBLE_INTS_SHORT) - || (!empty && preamble_ints == PREAMBLE_INTS_LONG) - { - return Ok(()); - } - let expected = if empty { - PREAMBLE_INTS_SHORT - } else { - PREAMBLE_INTS_LONG - }; - Err(Error::invalid_preamble_longs(expected, preamble_ints)) +fn random_bit(rng: &mut R) -> bool { + rng.next_bool() } -fn write_value(bytes: &mut SketchBytes, value: T) { - if std::mem::size_of::() == 4 { - bytes.write_f32_le(value.to_f64() as f32); - } else { - bytes.write_f64_le(value.to_f64()); +fn shuffle_with_rng(rng: &mut R, slice: &mut [T]) { + if slice.len() <= 1 { + return; } -} - -fn read_value(cursor: &mut SketchSlice<'_>) -> std::io::Result { - if std::mem::size_of::() == 4 { - cursor.read_f32_le().map(|v| T::from_f64(v as f64)) - } else { - cursor.read_f64_le().map(T::from_f64) + for i in (1..slice.len()).rev() { + let j = (rng.next_u64() % (i as u64 + 1)) as usize; + slice.swap(i, j); } } -thread_local! { - static RNG_STATE: Cell = Cell::new(seed_rng()); +fn write_f32_value(bytes: &mut SketchBytes, value: f32) { + bytes.write_f32_le(value); } -fn seed_rng() -> u64 { - let nanos = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_nanos(); - let mut seed = nanos as u64 ^ (std::process::id() as u64); - if seed == 0 { - seed = 0x9e3779b97f4a7c15; - } - seed +fn write_f64_value(bytes: &mut SketchBytes, value: f64) { + bytes.write_f64_le(value); } -fn next_u64() -> u64 { - RNG_STATE.with(|state| { - let mut x = state.get(); - x ^= x << 13; - x ^= x >> 7; - x ^= x << 17; - state.set(x); - x - }) +fn read_f32_value(cursor: &mut SketchSlice<'_>) -> std::io::Result { + cursor.read_f32_le() } -fn random_bit() -> bool { - (next_u64() & 1) != 0 +fn read_f64_value(cursor: &mut SketchSlice<'_>) -> std::io::Result { + cursor.read_f64_le() } -fn shuffle(slice: &mut [T]) { - if slice.len() <= 1 { - return; +fn serialize_inner( + sketch: &DensitySketch, + value_size: usize, + write_value: SerializeValue, +) -> Vec { + let preamble_ints = if sketch.is_empty() { + PREAMBLE_INTS_SHORT + } else { + PREAMBLE_INTS_LONG + }; + let mut size_bytes = preamble_ints as usize * 4; + if !sketch.is_empty() { + for level in &sketch.levels { + size_bytes += 4 + (level.len() * sketch.dim as usize * value_size); + } } - for i in (1..slice.len()).rev() { - let j = (next_u64() % (i as u64 + 1)) as usize; - slice.swap(i, j); + let mut bytes = SketchBytes::with_capacity(size_bytes); + bytes.write_u8(preamble_ints); + bytes.write_u8(SERIAL_VERSION); + bytes.write_u8(DENSITY_FAMILY_ID); + let flags = if sketch.is_empty() { FLAGS_IS_EMPTY } else { 0 }; + bytes.write_u8(flags); + bytes.write_u16_le(sketch.k); + bytes.write_u16_le(0); + bytes.write_u32_le(sketch.dim); + + if sketch.is_empty() { + return bytes.into_bytes(); + } + + bytes.write_u32_le(sketch.num_retained); + bytes.write_u64_le(sketch.n); + for level in &sketch.levels { + bytes.write_u32_le(level.len() as u32); + for point in level { + for value in point { + write_value(&mut bytes, *value); + } + } + } + bytes.into_bytes() +} + +fn deserialize_inner( + bytes: &[u8], + kernel: K, + rng: R, + read_value: DeserializeValue, +) -> Result, Error> { + fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { + move |_| Error::insufficient_data(tag) + } + + let mut cursor = SketchSlice::new(bytes); + let preamble_ints = cursor.read_u8().map_err(make_error("preamble_ints"))?; + let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; + let family_id = cursor.read_u8().map_err(make_error("family_id"))?; + let flags = cursor.read_u8().map_err(make_error("flags"))?; + let k = cursor.read_u16_le().map_err(make_error("k"))?; + cursor.read_u16_le().map_err(make_error("unused"))?; + let dim = cursor.read_u32_le().map_err(make_error("dim"))?; + + if family_id != DENSITY_FAMILY_ID { + return Err(Error::invalid_family( + DENSITY_FAMILY_ID, + family_id, + "DensitySketch", + )); + } + if serial_version != SERIAL_VERSION { + return Err(Error::unsupported_serial_version( + SERIAL_VERSION, + serial_version, + )); + } + if k < 2 { + return Err(Error::new( + ErrorKind::InvalidArgument, + format!("k must be > 1. Found: {k}"), + )); } + + let is_empty = (flags & FLAGS_IS_EMPTY) != 0; + let expected_preamble = if is_empty { + PREAMBLE_INTS_SHORT + } else { + PREAMBLE_INTS_LONG + }; + if preamble_ints != expected_preamble { + return Err(Error::invalid_preamble_longs( + expected_preamble, + preamble_ints, + )); + } + if is_empty { + return Ok(DensitySketch::with_kernel_and_rng(k, dim, kernel, rng)); + } + + let num_retained = cursor.read_u32_le().map_err(make_error("num_retained"))?; + let n = cursor.read_u64_le().map_err(make_error("n"))?; + + let mut levels = Vec::new(); + let mut remaining = num_retained as i64; + while remaining > 0 { + let level_size = cursor.read_u32_le().map_err(make_error("level_size"))?; + let mut level = Vec::with_capacity(level_size as usize); + for _ in 0..level_size { + let mut point = Vec::with_capacity(dim as usize); + for _ in 0..dim { + point.push(read_value(&mut cursor).map_err(make_error("point"))?); + } + level.push(point); + } + remaining -= level_size as i64; + levels.push(level); + } + if remaining != 0 { + return Err(Error::deserial( + "invalid number of retained points while decoding density sketch", + )); + } + + Ok(DensitySketch { + kernel, + rng, + k, + dim, + num_retained, + n, + levels, + }) } diff --git a/datasketches/src/lib.rs b/datasketches/src/lib.rs index 5f1ea6e..78a30fa 100644 --- a/datasketches/src/lib.rs +++ b/datasketches/src/lib.rs @@ -31,6 +31,7 @@ compile_error!("datasketches does not support big-endian targets"); pub mod bloom; +pub mod common; pub mod countmin; pub mod density; pub mod error; diff --git a/datasketches/tests/density_sketch_test.rs b/datasketches/tests/density_sketch_test.rs index 53586cc..f85c998 100644 --- a/datasketches/tests/density_sketch_test.rs +++ b/datasketches/tests/density_sketch_test.rs @@ -19,6 +19,7 @@ use std::io::Cursor; use datasketches::density::DensityKernel; use datasketches::density::DensitySketch; +use datasketches::density::DensityValue; #[test] #[should_panic(expected = "operation is undefined for an empty sketch")] @@ -74,7 +75,7 @@ fn test_iterator() { let mut count = 0; for item in &sketch { count += 1; - assert_eq!(item.point.len(), sketch.dim() as usize); + assert_eq!(item.point().len(), sketch.dim() as usize); } assert_eq!(count as u32, sketch.num_retained()); } @@ -84,14 +85,18 @@ struct SphericalKernel { radius_squared: f32, } -impl DensityKernel for SphericalKernel { - fn evaluate(&self, left: &[f32], right: &[f32]) -> f32 { - let mut sum = 0.0f32; +impl DensityKernel for SphericalKernel { + fn evaluate(&self, left: &[T], right: &[T]) -> T { + let mut sum = 0.0f64; for (a, b) in left.iter().zip(right.iter()) { - let diff = a - b; + let diff = a.to_f64() - b.to_f64(); sum += diff * diff; } - if sum <= self.radius_squared { 1.0 } else { 0.0 } + if sum <= self.radius_squared as f64 { + T::from_f64(1.0) + } else { + T::from_f64(0.0) + } } } @@ -100,7 +105,7 @@ fn test_custom_kernel() { let kernel = SphericalKernel { radius_squared: 0.25, }; - let mut sketch: DensitySketch = DensitySketch::with_kernel(10, 3, Box::new(kernel)); + let mut sketch: DensitySketch = DensitySketch::with_kernel(10, 3, kernel); sketch.update(vec![1.0, 1.0, 1.0]); assert_eq!(sketch.estimate(&[1.001, 1.001, 1.001]), 1.0); @@ -115,7 +120,7 @@ fn test_custom_kernel() { let mut count = 0; for item in &sketch { count += 1; - assert_eq!(item.point.len(), sketch.dim() as usize); + assert_eq!(item.point().len(), sketch.dim() as usize); } assert_eq!(count as u32, sketch.num_retained()); } @@ -134,8 +139,8 @@ fn test_serialize_empty() { let mut cursor = Cursor::new(Vec::new()); sketch.serialize_to_writer(&mut cursor).unwrap(); - cursor.set_position(0); - let decoded = DensitySketch::::deserialize_from_reader(&mut cursor).unwrap(); + let bytes = cursor.into_inner(); + let decoded = DensitySketch::::deserialize(&bytes).unwrap(); assert!(decoded.is_empty()); assert!(!decoded.is_estimation_mode()); assert_eq!(sketch.k(), decoded.k()); @@ -167,8 +172,8 @@ fn test_serialize_bytes() { let mut iter_left = sketch.iter(); let mut iter_right = decoded.iter(); while let (Some(left), Some(right)) = (iter_left.next(), iter_right.next()) { - assert_eq!(left.point[0], right.point[0]); - assert_eq!(left.weight, right.weight); + assert_eq!(left.point()[0], right.point()[0]); + assert_eq!(left.weight(), right.weight()); } let n = 1031; @@ -189,8 +194,8 @@ fn test_serialize_bytes() { let mut iter_left = sketch.iter(); let mut iter_right = decoded.iter(); while let (Some(left), Some(right)) = (iter_left.next(), iter_right.next()) { - assert_eq!(left.point[0], right.point[0]); - assert_eq!(left.weight, right.weight); + assert_eq!(left.point()[0], right.point()[0]); + assert_eq!(left.weight(), right.weight()); } } @@ -208,8 +213,8 @@ fn test_serialize_stream() { let mut cursor = Cursor::new(Vec::new()); sketch.serialize_to_writer(&mut cursor).unwrap(); - cursor.set_position(0); - let decoded = DensitySketch::::deserialize_from_reader(&mut cursor).unwrap(); + let bytes = cursor.into_inner(); + let decoded = DensitySketch::::deserialize(&bytes).unwrap(); assert!(!decoded.is_empty()); assert!(!decoded.is_estimation_mode()); assert_eq!(sketch.k(), decoded.k()); @@ -219,10 +224,10 @@ fn test_serialize_stream() { let mut iter_left = sketch.iter(); let mut iter_right = decoded.iter(); while let (Some(left), Some(right)) = (iter_left.next(), iter_right.next()) { - assert_eq!(left.point[0], right.point[0]); - assert_eq!(left.weight, right.weight); - assert_eq!(left.point[1], right.point[1]); - assert_eq!(left.point[2], right.point[2]); + assert_eq!(left.point()[0], right.point()[0]); + assert_eq!(left.weight(), right.weight()); + assert_eq!(left.point()[1], right.point()[1]); + assert_eq!(left.point()[2], right.point()[2]); } let n = 1031; @@ -234,8 +239,8 @@ fn test_serialize_stream() { let mut cursor = Cursor::new(Vec::new()); sketch.serialize_to_writer(&mut cursor).unwrap(); - cursor.set_position(0); - let decoded = DensitySketch::::deserialize_from_reader(&mut cursor).unwrap(); + let bytes = cursor.into_inner(); + let decoded = DensitySketch::::deserialize(&bytes).unwrap(); assert!(!decoded.is_empty()); assert!(decoded.is_estimation_mode()); assert_eq!(sketch.k(), decoded.k()); @@ -245,9 +250,9 @@ fn test_serialize_stream() { let mut iter_left = sketch.iter(); let mut iter_right = decoded.iter(); while let (Some(left), Some(right)) = (iter_left.next(), iter_right.next()) { - assert_eq!(left.point[0], right.point[0]); - assert_eq!(left.weight, right.weight); - assert_eq!(left.point[1], right.point[1]); - assert_eq!(left.point[2], right.point[2]); + assert_eq!(left.point()[0], right.point()[0]); + assert_eq!(left.weight(), right.weight()); + assert_eq!(left.point()[1], right.point()[1]); + assert_eq!(left.point()[2], right.point()[2]); } }