diff --git a/Cargo.lock b/Cargo.lock index cfb7542..c88f4fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1323,7 +1323,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.53.3", + "windows-targets 0.48.5", ] [[package]] @@ -1382,9 +1382,9 @@ dependencies = [ [[package]] name = "magnus" -version = "0.6.4" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1597ef40aa8c36be098249e82c9a20cf7199278ac1c1a1a995eeead6a184479" +checksum = "3b36a5b126bbe97eb0d02d07acfeb327036c6319fd816139a49824a83b7f9012" dependencies = [ "magnus-macros", "rb-sys", @@ -1394,9 +1394,9 @@ dependencies = [ [[package]] name = "magnus-macros" -version = "0.6.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5968c820e2960565f647819f5928a42d6e874551cab9d88d75e3e0660d7f71e3" +checksum = "47607461fd8e1513cb4f2076c197d8092d921a1ea75bd08af97398f593751892" dependencies = [ "proc-macro2", "quote", @@ -1970,18 +1970,18 @@ dependencies = [ [[package]] name = "rb-sys" -version = "0.9.117" +version = "0.9.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f900d1ce4629a2ebffaf5de74bd8f9c1188d4c5ed406df02f97e22f77a006f44" +checksum = "c85c4188462601e2aa1469def389c17228566f82ea72f137ed096f21591bc489" dependencies = [ "rb-sys-build", ] [[package]] name = "rb-sys-build" -version = "0.9.117" +version = "0.9.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef1e9c857028f631056bcd6d88cec390c751e343ce2223ddb26d23eb4a151d59" +checksum = "568068db4102230882e6d4ae8de6632e224ca75fe5970f6e026a04e91ed635d3" dependencies = [ "bindgen", "lazy_static", @@ -1994,9 +1994,9 @@ dependencies = [ [[package]] name = "rb-sys-env" -version = "0.1.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a35802679f07360454b418a5d1735c89716bde01d35b1560fc953c1415a0b3bb" +checksum = "cca7ad6a7e21e72151d56fe2495a259b5670e204c3adac41ee7ef676ea08117a" [[package]] name = "redox_syscall" diff --git a/clusterkit.gemspec b/clusterkit.gemspec index a2492eb..077516f 100644 --- a/clusterkit.gemspec +++ b/clusterkit.gemspec @@ -33,6 +33,7 @@ Gem::Specification.new do |spec| spec.add_dependency "rb_sys", "~> 0.9" # Development dependencies + spec.add_development_dependency "benchmark" spec.add_development_dependency "csv" spec.add_development_dependency "rake", "~> 13.0" spec.add_development_dependency "rake-compiler", "~> 1.2" diff --git a/ext/clusterkit/Cargo.toml b/ext/clusterkit/Cargo.toml index d2f44eb..f8e5e51 100644 --- a/ext/clusterkit/Cargo.toml +++ b/ext/clusterkit/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" crate-type = ["cdylib"] [dependencies] -magnus = { version = "0.6", features = ["embed"] } +magnus = { version = "0.8", features = ["embed"] } annembed = { git = "https://github.com/scientist-labs/annembed", tag = "clusterkit-0.1.1" } hnsw_rs = { git = "https://github.com/scientist-labs/hnswlib-rs", tag = "clusterkit-0.1.0" } hdbscan = "0.11" diff --git a/ext/clusterkit/src/clustering.rs b/ext/clusterkit/src/clustering.rs index 1c3a56e..1be5757 100644 --- a/ext/clusterkit/src/clustering.rs +++ b/ext/clusterkit/src/clustering.rs @@ -1,50 +1,52 @@ -use magnus::{function, prelude::*, Error, Value, RArray, Integer}; +use magnus::{function, prelude::*, Error, Value, RArray, Ruby}; use ndarray::{Array1, Array2, ArrayView1, Axis}; use rand::prelude::*; use rand::rngs::StdRng; use rand::SeedableRng; -use crate::utils::{ruby_array_to_ndarray2}; +use crate::utils::ruby_array_to_ndarray2; mod hdbscan_wrapper; pub fn init(parent: &magnus::RModule) -> Result<(), Error> { let clustering_module = parent.define_module("Clustering")?; - + clustering_module.define_singleton_method( "kmeans_rust", function!(kmeans, 4), )?; - + clustering_module.define_singleton_method( "kmeans_predict_rust", function!(kmeans_predict, 2), )?; - + // Initialize HDBSCAN functions hdbscan_wrapper::init(&clustering_module)?; - + Ok(()) } /// Perform K-means clustering /// Returns (labels, centroids, inertia) fn kmeans(data: Value, k: usize, max_iter: usize, random_seed: Option) -> Result<(RArray, RArray, f64), Error> { + let ruby = Ruby::get().unwrap(); + // Convert Ruby array to ndarray using shared helper let data_array = ruby_array_to_ndarray2(data)?; let (n_samples, n_features) = data_array.dim(); - + if k > n_samples { return Err(Error::new( - magnus::exception::arg_error(), + ruby.exception_arg_error(), format!("k ({}) cannot be larger than number of samples ({})", k, n_samples), )); } - + // Initialize centroids using K-means++ let mut centroids = kmeans_plusplus(&data_array, k, random_seed)?; let mut labels = vec![0usize; n_samples]; let mut prev_labels = vec![0usize; n_samples]; - + // K-means iterations for iteration in 0..max_iter { // Assign points to nearest centroid @@ -53,7 +55,7 @@ fn kmeans(data: Value, k: usize, max_iter: usize, random_seed: Option) -> R let point = data_array.row(i); let mut min_dist = f64::INFINITY; let mut best_cluster = 0; - + for (j, centroid) in centroids.axis_iter(Axis(0)).enumerate() { let dist = euclidean_distance(&point, ¢roid); if dist < min_dist { @@ -61,38 +63,38 @@ fn kmeans(data: Value, k: usize, max_iter: usize, random_seed: Option) -> R best_cluster = j; } } - + if labels[i] != best_cluster { changed = true; } labels[i] = best_cluster; } - + // Check for convergence if !changed && iteration > 0 { break; } - + // Update centroids for j in 0..k { let mut sum = Array1::::zeros(n_features); let mut count = 0; - + for i in 0..n_samples { if labels[i] == j { sum += &data_array.row(i); count += 1; } } - + if count > 0 { centroids.row_mut(j).assign(&(sum / count as f64)); } } - + prev_labels.clone_from(&labels); } - + // Calculate inertia (sum of squared distances to nearest centroid) let mut inertia = 0.0; for i in 0..n_samples { @@ -100,44 +102,43 @@ fn kmeans(data: Value, k: usize, max_iter: usize, random_seed: Option) -> R let centroid = centroids.row(labels[i]); inertia += euclidean_distance(&point, ¢roid).powi(2); } - + // Convert results to Ruby arrays - let ruby = magnus::Ruby::get().unwrap(); - let labels_array = RArray::new(); + let labels_array = ruby.ary_new(); for label in labels { - labels_array.push(Integer::from_value(ruby.eval(&format!("{}", label)).unwrap()).unwrap())?; + labels_array.push(ruby.integer_from_i64(label as i64))?; } - - let centroids_array = RArray::new(); + + let centroids_array = ruby.ary_new(); for i in 0..k { - let row_array = RArray::new(); + let row_array = ruby.ary_new(); for j in 0..n_features { row_array.push(centroids[[i, j]])?; } centroids_array.push(row_array)?; } - + Ok((labels_array, centroids_array, inertia)) } /// Predict cluster labels for new data given centroids fn kmeans_predict(data: Value, centroids: Value) -> Result { + let ruby = Ruby::get().unwrap(); + // Convert inputs using shared helpers let data_matrix = ruby_array_to_ndarray2(data)?; let centroids_matrix = ruby_array_to_ndarray2(centroids)?; - + let (n_samples, _) = data_matrix.dim(); - let (_k, _) = centroids_matrix.dim(); - + // Predict labels - let ruby = magnus::Ruby::get().unwrap(); - let labels_array = RArray::new(); - + let labels_array = ruby.ary_new(); + for i in 0..n_samples { let point = data_matrix.row(i); let mut min_dist = f64::INFINITY; let mut best_cluster = 0; - + for (j, centroid) in centroids_matrix.axis_iter(Axis(0)).enumerate() { let dist = euclidean_distance(&point, ¢roid); if dist < min_dist { @@ -145,10 +146,10 @@ fn kmeans_predict(data: Value, centroids: Value) -> Result { best_cluster = j; } } - - labels_array.push(Integer::from_value(ruby.eval(&format!("{}", best_cluster)).unwrap()).unwrap())?; + + labels_array.push(ruby.integer_from_i64(best_cluster as i64))?; } - + Ok(labels_array) } @@ -156,28 +157,26 @@ fn kmeans_predict(data: Value, centroids: Value) -> Result { fn kmeans_plusplus(data: &Array2, k: usize, random_seed: Option) -> Result, Error> { let n_samples = data.nrows(); let n_features = data.ncols(); - + // Use seeded RNG if seed is provided, otherwise use thread_rng let mut rng: Box = match random_seed { Some(seed) => { - // Convert i64 to u64 for seeding (negative numbers wrap around) let seed_u64 = seed as u64; Box::new(StdRng::seed_from_u64(seed_u64)) }, None => Box::new(thread_rng()), }; - + let mut centroids = Array2::::zeros((k, n_features)); - + // Choose first centroid randomly let first_idx = rng.gen_range(0..n_samples); centroids.row_mut(0).assign(&data.row(first_idx)); - + // Choose remaining centroids for i in 1..k { let mut distances = vec![f64::INFINITY; n_samples]; - - // Calculate distance to nearest centroid for each point + for j in 0..n_samples { for c in 0..i { let dist = euclidean_distance(&data.row(j), ¢roids.row(c)); @@ -186,25 +185,20 @@ fn kmeans_plusplus(data: &Array2, k: usize, random_seed: Option) -> Re } } } - - // Convert distances to probabilities + let total: f64 = distances.iter().map(|d| d * d).sum(); if total == 0.0 { - // All points are identical or we've selected duplicates - // Just use sequential points as centroids if i < n_samples { centroids.row_mut(i).assign(&data.row(i)); } else { - // Reuse first point if we run out centroids.row_mut(i).assign(&data.row(0)); } continue; } - - // Choose next centroid with probability proportional to squared distance + let mut cumsum = 0.0; let rand_val: f64 = rng.gen::() * total; - + for j in 0..n_samples { cumsum += distances[j] * distances[j]; if cumsum >= rand_val { @@ -213,7 +207,7 @@ fn kmeans_plusplus(data: &Array2, k: usize, random_seed: Option) -> Re } } } - + Ok(centroids) } @@ -224,4 +218,4 @@ fn euclidean_distance(a: &ArrayView1, b: &ArrayView1) -> f64 { .map(|(x, y)| (x - y).powi(2)) .sum::() .sqrt() -} \ No newline at end of file +} diff --git a/ext/clusterkit/src/clustering/hdbscan_wrapper.rs b/ext/clusterkit/src/clustering/hdbscan_wrapper.rs index 24fe084..1cf2dab 100644 --- a/ext/clusterkit/src/clustering/hdbscan_wrapper.rs +++ b/ext/clusterkit/src/clustering/hdbscan_wrapper.rs @@ -1,4 +1,4 @@ -use magnus::{function, prelude::*, Error, Value, RArray, RHash, Integer}; +use magnus::{function, prelude::*, Error, Value, RHash, Ruby}; use hdbscan::{Hdbscan, HdbscanHyperParams}; use crate::utils::ruby_array_to_vec_vec_f64; @@ -10,75 +10,62 @@ pub fn hdbscan_fit( min_cluster_size: usize, metric: String, ) -> Result { + let ruby = Ruby::get().unwrap(); + // Convert Ruby array to Vec> using shared helper let data_vec = ruby_array_to_vec_vec_f64(data)?; let n_samples = data_vec.len(); - - // Note: hdbscan crate doesn't support custom metrics directly - // We'll use the default Euclidean distance for now + if metric != "euclidean" && metric != "l2" { eprintln!("Warning: Current hdbscan version only supports Euclidean distance. Using Euclidean."); } - + // Adjust parameters to avoid index out of bounds errors - // The hdbscan crate has issues when min_samples >= n_samples let adjusted_min_samples = min_samples.min(n_samples.saturating_sub(1)).max(1); let adjusted_min_cluster_size = min_cluster_size.min(n_samples).max(2); - + // Create hyperparameters let hyper_params = HdbscanHyperParams::builder() .min_cluster_size(adjusted_min_cluster_size) .min_samples(adjusted_min_samples) .build(); - + // Create HDBSCAN instance and run clustering let clusterer = Hdbscan::new(&data_vec, hyper_params); - - // Run the clustering algorithm - cluster() returns Result, HdbscanError> + let labels = clusterer.cluster().map_err(|e| { Error::new( - magnus::exception::runtime_error(), + ruby.exception_runtime_error(), format!("HDBSCAN clustering failed: {:?}", e) ) })?; - + // Convert results to Ruby types - let ruby = magnus::Ruby::get().unwrap(); - let result = RHash::new(); - - // Convert labels (i32 to Ruby Integer, -1 for noise) - let labels_array = RArray::new(); + let result = ruby.hash_new(); + + let labels_array = ruby.ary_new(); for &label in labels.iter() { - labels_array.push(Integer::from_value( - ruby.eval(&format!("{}", label)).unwrap() - ).unwrap())?; + labels_array.push(ruby.integer_from_i64(label as i64))?; } result.aset("labels", labels_array)?; - - // For now, we'll create dummy probabilities and outlier scores - // since the basic hdbscan crate doesn't provide these - // In the future, we could calculate these ourselves or use a more advanced implementation - - // Create probabilities array (all 1.0 for clustered points, 0.0 for noise) - let probs_array = RArray::new(); + + let probs_array = ruby.ary_new(); for &label in labels.iter() { let prob = if label == -1 { 0.0 } else { 1.0 }; probs_array.push(prob)?; } result.aset("probabilities", probs_array)?; - - // Create outlier scores array (0.0 for clustered points, 1.0 for noise) - let outlier_array = RArray::new(); + + let outlier_array = ruby.ary_new(); for &label in labels.iter() { let score = if label == -1 { 1.0 } else { 0.0 }; outlier_array.push(score)?; } result.aset("outlier_scores", outlier_array)?; - - // Create empty cluster persistence hash for now - let persistence_hash = RHash::new(); + + let persistence_hash = ruby.hash_new(); result.aset("cluster_persistence", persistence_hash)?; - + Ok(result) } @@ -88,6 +75,6 @@ pub fn init(clustering_module: &magnus::RModule) -> Result<(), Error> { "hdbscan_rust", function!(hdbscan_fit, 4), )?; - + Ok(()) -} \ No newline at end of file +} diff --git a/ext/clusterkit/src/embedder.rs b/ext/clusterkit/src/embedder.rs index 94fbfd9..06e8039 100644 --- a/ext/clusterkit/src/embedder.rs +++ b/ext/clusterkit/src/embedder.rs @@ -1,4 +1,4 @@ -use magnus::{Error, RArray, RHash, Value, TryConvert, Integer, Module, Object}; +use magnus::{Error, RArray, RHash, Value, TryConvert, Integer, Module, Object, Ruby}; use magnus::value::ReprValue; use hnsw_rs::prelude::*; use annembed::prelude::*; @@ -21,7 +21,8 @@ struct SavedUMAPModel { } pub fn init(parent: &magnus::RModule) -> Result<(), Error> { - let umap_class = parent.define_class("RustUMAP", magnus::class::object())?; + let ruby = Ruby::get().unwrap(); + let umap_class = parent.define_class("RustUMAP", ruby.class_object())?; umap_class.define_singleton_method("new", magnus::function!(RustUMAP::new, 1))?; umap_class.define_singleton_method("load_model", magnus::function!(RustUMAP::load_model, 1))?; @@ -40,15 +41,15 @@ struct RustUMAP { random_seed: Option, nb_grad_batch: usize, nb_sampling_by_edge: usize, - // Store the training data and embeddings for transform approximation - // Use RefCell for interior mutability training_data: RefCell>>>, training_embeddings: RefCell>>>, } impl RustUMAP { fn new(options: RHash) -> Result { - let n_components = match options.lookup::<_, Value>(magnus::Symbol::new("n_components")) { + let ruby = Ruby::get().unwrap(); + + let n_components = match options.lookup::<_, Value>(ruby.to_symbol("n_components")) { Ok(val) => { if val.is_nil() { 2 @@ -61,7 +62,7 @@ impl RustUMAP { Err(_) => 2, }; - let n_neighbors = match options.lookup::<_, Value>(magnus::Symbol::new("n_neighbors")) { + let n_neighbors = match options.lookup::<_, Value>(ruby.to_symbol("n_neighbors")) { Ok(val) => { if val.is_nil() { 15 @@ -74,7 +75,7 @@ impl RustUMAP { Err(_) => 15, }; - let random_seed = match options.lookup::<_, Value>(magnus::Symbol::new("random_seed")) { + let random_seed = match options.lookup::<_, Value>(ruby.to_symbol("random_seed")) { Ok(val) => { if val.is_nil() { None @@ -87,10 +88,10 @@ impl RustUMAP { Err(_) => None, }; - let nb_grad_batch = match options.lookup::<_, Value>(magnus::Symbol::new("nb_grad_batch")) { + let nb_grad_batch = match options.lookup::<_, Value>(ruby.to_symbol("nb_grad_batch")) { Ok(val) => { if val.is_nil() { - 10 // Default value + 10 } else { Integer::try_convert(val) .map(|i| i.to_u32().unwrap_or(10) as usize) @@ -99,11 +100,11 @@ impl RustUMAP { } Err(_) => 10, }; - - let nb_sampling_by_edge = match options.lookup::<_, Value>(magnus::Symbol::new("nb_sampling_by_edge")) { + + let nb_sampling_by_edge = match options.lookup::<_, Value>(ruby.to_symbol("nb_sampling_by_edge")) { Ok(val) => { if val.is_nil() { - 8 // Default value + 8 } else { Integer::try_convert(val) .map(|i| i.to_u32().unwrap_or(8) as usize) @@ -125,6 +126,8 @@ impl RustUMAP { } fn fit_transform(&self, data: Value) -> Result { + let ruby = Ruby::get().unwrap(); + // Convert Ruby array to Rust Vec> using shared helper let data_f32 = ruby_array_to_vec_vec_f32(data)?; @@ -149,9 +152,7 @@ impl RustUMAP { .enumerate() .map(|(i, v)| (v, i)) .collect(); - - // Use serial_insert for reproducibility when seed is provided, - // parallel_insert for performance when no seed + if self.random_seed.is_some() { hnsw.serial_insert(&data_with_id); } else { @@ -160,36 +161,34 @@ impl RustUMAP { // Create KGraph from HNSW let kgraph: annembed::fromhnsw::kgraph::KGraph = annembed::fromhnsw::kgraph::kgraph_from_hnsw_all(&hnsw, self.n_neighbors) - .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?; + .map_err(|e| Error::new(ruby.exception_runtime_error(), e.to_string()))?; // Set up embedding parameters let mut embed_params = EmbedderParams::default(); embed_params.asked_dim = self.n_components; - embed_params.nb_grad_batch = self.nb_grad_batch; // Configurable from Ruby + embed_params.nb_grad_batch = self.nb_grad_batch; embed_params.scale_rho = 1.; embed_params.beta = 1.; embed_params.b = 1.; embed_params.grad_step = 1.; - embed_params.nb_sampling_by_edge = self.nb_sampling_by_edge; // Configurable from Ruby - // Enable diffusion map initialization (annembed now has fallback to random if it fails) + embed_params.nb_sampling_by_edge = self.nb_sampling_by_edge; embed_params.dmap_init = true; - embed_params.random_seed = self.random_seed; // Pass seed through to annembed + embed_params.random_seed = self.random_seed; // Create embedder and perform embedding let mut embedder = Embedder::new(&kgraph, embed_params); let embed_result = embedder.embed() - .map_err(|e| Error::new(magnus::exception::runtime_error(), + .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Embedding failed: {}", e)))?; if embed_result == 0 { - return Err(Error::new(magnus::exception::runtime_error(), "No points were embedded")); + return Err(Error::new(ruby.exception_runtime_error(), "No points were embedded")); } // Get embedded data let embedded_array = embedder.get_embedded_reindexed(); - // Store results in a simpler format let mut embeddings = Vec::new(); for i in 0..embedded_array.nrows() { let mut row = Vec::new(); @@ -198,13 +197,15 @@ impl RustUMAP { } embeddings.push(row); } + // Store the training data and embeddings for future transforms *self.training_data.borrow_mut() = Some(data_f32.clone()); *self.training_embeddings.borrow_mut() = Some(embeddings.clone()); + // Convert result back to Ruby array - let result = RArray::new(); + let result = ruby.ary_new(); for embedding in &embeddings { - let row = RArray::new(); + let row = ruby.ary_new(); for &val in embedding { row.push(val)?; } @@ -213,16 +214,15 @@ impl RustUMAP { Ok(result) } - // Save the full model (training data + embeddings + params) for future transforms fn save_model(&self, path: String) -> Result<(), Error> { - // Check if we have training data + let ruby = Ruby::get().unwrap(); let training_data = self.training_data.borrow(); let training_embeddings = self.training_embeddings.borrow(); let training_data_ref = training_data.as_ref() - .ok_or_else(|| Error::new(magnus::exception::runtime_error(), "No model to save. Run fit_transform first."))?; + .ok_or_else(|| Error::new(ruby.exception_runtime_error(), "No model to save. Run fit_transform first."))?; let training_embeddings_ref = training_embeddings.as_ref() - .ok_or_else(|| Error::new(magnus::exception::runtime_error(), "No embeddings to save."))?; + .ok_or_else(|| Error::new(ruby.exception_runtime_error(), "No embeddings to save."))?; let saved_model = SavedUMAPModel { n_components: self.n_components, @@ -234,28 +234,29 @@ impl RustUMAP { }; let serialized = bincode::serialize(&saved_model) - .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?; + .map_err(|e| Error::new(ruby.exception_runtime_error(), e.to_string()))?; let mut file = File::create(&path) - .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?; + .map_err(|e| Error::new(ruby.exception_runtime_error(), e.to_string()))?; file.write_all(&serialized) - .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?; + .map_err(|e| Error::new(ruby.exception_runtime_error(), e.to_string()))?; Ok(()) } - // Load a full model for transforming new data fn load_model(path: String) -> Result { + let ruby = Ruby::get().unwrap(); + let mut file = File::open(&path) - .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?; + .map_err(|e| Error::new(ruby.exception_runtime_error(), e.to_string()))?; let mut buffer = Vec::new(); file.read_to_end(&mut buffer) - .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?; + .map_err(|e| Error::new(ruby.exception_runtime_error(), e.to_string()))?; let saved_model: SavedUMAPModel = bincode::deserialize(&buffer) - .map_err(|e| Error::new(magnus::exception::runtime_error(), e.to_string()))?; + .map_err(|e| Error::new(ruby.exception_runtime_error(), e.to_string()))?; Ok(RustUMAP { n_components: saved_model.n_components, @@ -268,43 +269,36 @@ impl RustUMAP { }) } - // Transform new data using k-NN approximation with the training data fn transform(&self, data: Value) -> Result { - // Get training data + let ruby = Ruby::get().unwrap(); let training_data = self.training_data.borrow(); let training_embeddings = self.training_embeddings.borrow(); let training_data_ref = training_data.as_ref() - .ok_or_else(|| Error::new(magnus::exception::runtime_error(), "No model loaded. Load a model or run fit_transform first."))?; + .ok_or_else(|| Error::new(ruby.exception_runtime_error(), "No model loaded. Load a model or run fit_transform first."))?; let training_embeddings_ref = training_embeddings.as_ref() - .ok_or_else(|| Error::new(magnus::exception::runtime_error(), "No embeddings available."))?; + .ok_or_else(|| Error::new(ruby.exception_runtime_error(), "No embeddings available."))?; - // Convert input data to Rust format using shared helper let new_data = ruby_array_to_vec_vec_f32(data)?; - // For each new point, find k nearest neighbors in training data - // and average their embeddings (weighted by distance) let k = self.n_neighbors.min(training_data_ref.len()); - let result = RArray::new(); + let result = ruby.ary_new(); for new_point in &new_data { - // Calculate distances to all training points let mut distances: Vec<(f32, usize)> = Vec::new(); for (idx, train_point) in training_data_ref.iter().enumerate() { let dist = euclidean_distance(new_point, train_point); distances.push((dist, idx)); } - // Sort by distance and take k nearest distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); let k_nearest = &distances[..k]; - // Weighted average of k nearest embeddings let mut avg_embedding = vec![0.0; self.n_components]; let mut total_weight = 0.0; for &(dist, idx) in k_nearest { - let weight = 1.0 / (dist as f64 + 0.001); // Inverse distance weighting + let weight = 1.0 / (dist as f64 + 0.001); total_weight += weight; for (i, &val) in training_embeddings_ref[idx].iter().enumerate() { @@ -312,13 +306,11 @@ impl RustUMAP { } } - // Normalize for val in &mut avg_embedding { *val /= total_weight; } - // Convert to Ruby array - let row = RArray::new(); + let row = ruby.ary_new(); for val in avg_embedding { row.push(val)?; } @@ -335,4 +327,4 @@ fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 { .map(|(x, y)| (x - y).powi(2)) .sum::() .sqrt() -} \ No newline at end of file +} diff --git a/ext/clusterkit/src/hnsw.rs b/ext/clusterkit/src/hnsw.rs index 99e60a5..3d03943 100644 --- a/ext/clusterkit/src/hnsw.rs +++ b/ext/clusterkit/src/hnsw.rs @@ -1,10 +1,10 @@ use magnus::{ - class, exception, function, method, prelude::*, - Error, Float, Integer, RArray, RHash, RString, Symbol, Value, value, TryConvert, r_hash::ForEach + function, method, prelude::*, + Error, Float, Integer, RArray, RHash, RString, Symbol, Value, TryConvert, Ruby, + r_hash::ForEach, }; use hnsw_rs::prelude::*; use hnsw_rs::hnswio::HnswIo; -// use ndarray::Array1; // Not used currently use std::collections::HashMap; use std::sync::{Arc, Mutex}; use serde::{Serialize, Deserialize}; @@ -30,7 +30,7 @@ pub struct HnswIndex { } #[derive(Clone, Copy)] -#[allow(dead_code)] // These variants will be implemented in the future +#[allow(dead_code)] enum DistanceType { Euclidean, Cosine, @@ -38,88 +38,83 @@ enum DistanceType { } impl HnswIndex { - // Initialize a new HNSW index pub fn new(kwargs: RHash) -> Result { - // Parse arguments - let dim_opt: Option = kwargs.delete(Symbol::new("dim"))?; - let dim_value = dim_opt.ok_or_else(|| Error::new(exception::arg_error(), "dim is required"))?; + let ruby = Ruby::get().unwrap(); + + let dim_opt: Option = kwargs.delete(ruby.to_symbol("dim"))?; + let dim_value = dim_opt.ok_or_else(|| Error::new(ruby.exception_arg_error(), "dim is required"))?; let dim: usize = TryConvert::try_convert(dim_value) - .map_err(|_| Error::new(exception::arg_error(), "dim must be an integer"))?; - - // Validate dimension + .map_err(|_| Error::new(ruby.exception_arg_error(), "dim must be an integer"))?; + if dim == 0 { - return Err(Error::new(exception::arg_error(), "dim must be a positive integer (got 0)")); + return Err(Error::new(ruby.exception_arg_error(), "dim must be a positive integer (got 0)")); } - - let space: String = if let Some(v) = kwargs.delete(Symbol::new("space"))? { - // Convert Ruby symbol to string properly + + let space: String = if let Some(v) = kwargs.delete(ruby.to_symbol("space"))? { if let Ok(sym) = Symbol::try_convert(v) { sym.name()?.to_string() } else if let Ok(s) = String::try_convert(v) { s } else { return Err(Error::new( - exception::type_error(), + ruby.exception_type_error(), "space must be a string or symbol" )); } } else { "euclidean".to_string() }; - - let max_elements: usize = if let Some(v) = kwargs.delete(Symbol::new("max_elements"))? { + + let max_elements: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("max_elements"))? { TryConvert::try_convert(v).unwrap_or(10_000) } else { 10_000 }; - - let m: usize = if let Some(v) = kwargs.delete(Symbol::new("M"))? { + + let m: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("M"))? { TryConvert::try_convert(v).unwrap_or(16) } else { 16 }; - - let ef_construction: usize = if let Some(v) = kwargs.delete(Symbol::new("ef_construction"))? { + + let ef_construction: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("ef_construction"))? { TryConvert::try_convert(v).unwrap_or(200) } else { 200 }; - - let random_seed: Option = if let Some(v) = kwargs.delete(Symbol::new("random_seed"))? { + + let random_seed: Option = if let Some(v) = kwargs.delete(ruby.to_symbol("random_seed"))? { TryConvert::try_convert(v).ok() } else { None }; - - // Validate and convert space parameter - // For now, only support Euclidean distance + let distance_type = match space.as_str() { "euclidean" => DistanceType::Euclidean, "cosine" => { return Err(Error::new( - exception::runtime_error(), + ruby.exception_runtime_error(), "Cosine distance is not yet implemented, please use :euclidean" )); }, "inner_product" => { return Err(Error::new( - exception::runtime_error(), + ruby.exception_runtime_error(), "Inner product distance is not yet implemented, please use :euclidean" )); }, _ => return Err(Error::new( - exception::arg_error(), + ruby.exception_arg_error(), format!("space must be :euclidean, :cosine, or :inner_product (got: {})", space) )), }; - - // Create HNSW instance with Euclidean distance + let hnsw = if let Some(seed) = random_seed { Hnsw::::new_with_seed(m, max_elements, 16, ef_construction, DistL2, seed) } else { Hnsw::::new(m, max_elements, 16, ef_construction, DistL2) }; - + Ok(Self { hnsw: Arc::new(Mutex::new(hnsw)), dim, @@ -130,14 +125,13 @@ impl HnswIndex { ef_search: Arc::new(Mutex::new(ef_construction)), }) } - - // Add a single item to the index + pub fn add_item(&self, vector: RArray, kwargs: RHash) -> Result { - // Parse vector - let vec_data = parse_vector(vector, self.dim)?; - - // Get or generate label - let label: String = if let Some(v) = kwargs.delete(Symbol::new("label"))? { + let ruby = Ruby::get().unwrap(); + + let vec_data = parse_vector(&ruby, vector, self.dim)?; + + let label: String = if let Some(v) = kwargs.delete(ruby.to_symbol("label"))? { TryConvert::try_convert(v).unwrap_or_else(|_| { let mut id = self.current_id.lock().unwrap(); let label = id.to_string(); @@ -150,33 +144,30 @@ impl HnswIndex { *id += 1; label }; - - // Get metadata if provided - let metadata: Option> = if let Some(v) = kwargs.delete(Symbol::new("metadata"))? { - Some(parse_metadata(v)?) + + let metadata: Option> = if let Some(v) = kwargs.delete(ruby.to_symbol("metadata"))? { + Some(parse_metadata(&ruby, v)?) } else { None }; - - // Get internal ID for this item + let internal_id = { let mut label_map = self.label_to_id.lock().unwrap(); let mut current_id = self.current_id.lock().unwrap(); - + if label_map.contains_key(&label) { return Err(Error::new( - exception::arg_error(), + ruby.exception_arg_error(), format!("Label '{}' already exists in index", label) )); } - + let id = *current_id; label_map.insert(label.clone(), id); *current_id += 1; id }; - - // Store metadata + { let mut metadata_store = self.metadata_store.lock().unwrap(); metadata_store.insert(internal_id, ItemMetadata { @@ -184,39 +175,38 @@ impl HnswIndex { metadata, }); } - - // Add to HNSW + { let hnsw = self.hnsw.lock().unwrap(); hnsw.insert((&vec_data, internal_id)); } - - Ok(value::qnil().as_value()) + + Ok(ruby.qnil().as_value()) } - - // Add multiple items in batch + pub fn add_batch(&self, vectors: RArray, kwargs: RHash) -> Result { - let parallel: bool = if let Some(v) = kwargs.delete(Symbol::new("parallel"))? { + let ruby = Ruby::get().unwrap(); + + let parallel: bool = if let Some(v) = kwargs.delete(ruby.to_symbol("parallel"))? { TryConvert::try_convert(v).unwrap_or(true) } else { true }; - - let labels: Option = if let Some(v) = kwargs.delete(Symbol::new("labels"))? { + + let labels: Option = if let Some(v) = kwargs.delete(ruby.to_symbol("labels"))? { TryConvert::try_convert(v).ok() } else { None }; - - // Parse all vectors + let mut data_points: Vec<(Vec, usize)> = Vec::new(); let mut metadata_entries: Vec<(usize, ItemMetadata)> = Vec::new(); - - for (i, vector) in vectors.each().enumerate() { - let vector: RArray = TryConvert::try_convert(vector?)?; - let vec_data = parse_vector(vector, self.dim)?; - - // Get or generate label + + let len = vectors.len(); + for i in 0..len { + let vector: RArray = vectors.entry(i as isize)?; + let vec_data = parse_vector(&ruby, vector, self.dim)?; + let label = if let Some(ref labels_array) = labels { labels_array.entry::(i as isize)? } else { @@ -225,41 +215,38 @@ impl HnswIndex { *id += 1; label }; - - // Get internal ID + let internal_id = { let mut label_map = self.label_to_id.lock().unwrap(); let mut current_id = self.current_id.lock().unwrap(); - + if label_map.contains_key(&label) { return Err(Error::new( - exception::arg_error(), + ruby.exception_arg_error(), format!("Label '{}' already exists in index", label) )); } - + let id = *current_id; label_map.insert(label.clone(), id); *current_id += 1; id }; - + data_points.push((vec_data, internal_id)); metadata_entries.push((internal_id, ItemMetadata { label, metadata: None, })); } - - // Store metadata + { let mut metadata_store = self.metadata_store.lock().unwrap(); for (id, metadata) in metadata_entries { metadata_store.insert(id, metadata); } } - - // Insert into HNSW + { let hnsw = self.hnsw.lock().unwrap(); if parallel { @@ -271,57 +258,54 @@ impl HnswIndex { } } } - - Ok(value::qnil().as_value()) + + Ok(ruby.qnil().as_value()) } - - // Search for k nearest neighbors + pub fn search(&self, query: RArray, kwargs: RHash) -> Result { - let k: usize = if let Some(v) = kwargs.delete(Symbol::new("k"))? { + let ruby = Ruby::get().unwrap(); + + let k: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("k"))? { TryConvert::try_convert(v).unwrap_or(10) } else { 10 }; - - let include_distances: bool = if let Some(v) = kwargs.delete(Symbol::new("include_distances"))? { + + let include_distances: bool = if let Some(v) = kwargs.delete(ruby.to_symbol("include_distances"))? { TryConvert::try_convert(v).unwrap_or(false) } else { false }; - - // Parse query vector - let query_vec = parse_vector(query, self.dim)?; - - // Set search ef if provided - if let Some(v) = kwargs.delete(Symbol::new("ef"))? { + + let query_vec = parse_vector(&ruby, query, self.dim)?; + + if let Some(v) = kwargs.delete(ruby.to_symbol("ef"))? { if let Ok(ef) = TryConvert::try_convert(v) as Result { let mut ef_search = self.ef_search.lock().unwrap(); *ef_search = ef; } } - - // Perform search + let neighbors = { let hnsw = self.hnsw.lock().unwrap(); let ef_search = self.ef_search.lock().unwrap(); hnsw.search(&query_vec, k, *ef_search) }; - - // Convert results + let metadata_store = self.metadata_store.lock().unwrap(); - - let indices = RArray::new(); - let distances = RArray::new(); - + + let indices = ruby.ary_new(); + let distances = ruby.ary_new(); + for neighbor in neighbors { if let Some(metadata) = metadata_store.get(&neighbor.d_id) { - indices.push(RString::new(&metadata.label))?; - distances.push(Float::from_f64(neighbor.distance as f64))?; + indices.push(ruby.str_new(&metadata.label))?; + distances.push(ruby.float_from_f64(neighbor.distance as f64))?; } } - + if include_distances { - let result = RArray::new(); + let result = ruby.ary_new(); result.push(indices)?; result.push(distances)?; Ok(result.as_value()) @@ -329,114 +313,107 @@ impl HnswIndex { Ok(indices.as_value()) } } - - // Search with metadata included + pub fn search_with_metadata(&self, query: RArray, kwargs: RHash) -> Result { - let k: usize = if let Some(v) = kwargs.delete(Symbol::new("k"))? { + let ruby = Ruby::get().unwrap(); + + let k: usize = if let Some(v) = kwargs.delete(ruby.to_symbol("k"))? { TryConvert::try_convert(v).unwrap_or(10) } else { 10 }; - - // Parse query vector - let query_vec = parse_vector(query, self.dim)?; - - // Perform search + + let query_vec = parse_vector(&ruby, query, self.dim)?; + let neighbors = { let hnsw = self.hnsw.lock().unwrap(); let ef_search = self.ef_search.lock().unwrap(); hnsw.search(&query_vec, k, *ef_search) }; - - // Build results with metadata + let metadata_store = self.metadata_store.lock().unwrap(); - let results = RArray::new(); - + let results = ruby.ary_new(); + for neighbor in neighbors { if let Some(item_metadata) = metadata_store.get(&neighbor.d_id) { - let result = RHash::new(); - result.aset(Symbol::new("label"), RString::new(&item_metadata.label))?; - result.aset(Symbol::new("distance"), Float::from_f64(neighbor.distance as f64))?; - - let meta_hash = RHash::new(); + let result = ruby.hash_new(); + result.aset(ruby.to_symbol("label"), ruby.str_new(&item_metadata.label))?; + result.aset(ruby.to_symbol("distance"), ruby.float_from_f64(neighbor.distance as f64))?; + + let meta_hash = ruby.hash_new(); if let Some(ref meta) = item_metadata.metadata { for (key, value) in meta { - meta_hash.aset(RString::new(key), RString::new(value))?; + meta_hash.aset(ruby.str_new(key), ruby.str_new(value))?; } } - result.aset(Symbol::new("metadata"), meta_hash)?; - + result.aset(ruby.to_symbol("metadata"), meta_hash)?; + results.push(result)?; } } - + Ok(results.as_value()) } - - // Get current size of the index + pub fn size(&self) -> Result { let metadata_store = self.metadata_store.lock().unwrap(); Ok(metadata_store.len()) } - - // Check if index is empty + pub fn empty(&self) -> Result { Ok(self.size()? == 0) } - - // Set the ef parameter for search + pub fn set_ef(&self, ef: usize) -> Result { + let ruby = Ruby::get().unwrap(); let mut ef_search = self.ef_search.lock().unwrap(); *ef_search = ef; - Ok(value::qnil().as_value()) + Ok(ruby.qnil().as_value()) } - - // Get configuration + pub fn config(&self) -> Result { - let config = RHash::new(); - config.aset(Symbol::new("dim"), Integer::from_i64(self.dim as i64))?; - + let ruby = Ruby::get().unwrap(); + let config = ruby.hash_new(); + config.aset(ruby.to_symbol("dim"), ruby.integer_from_i64(self.dim as i64))?; + let space_str = match self.space { DistanceType::Euclidean => "euclidean", DistanceType::Cosine => "cosine", DistanceType::InnerProduct => "inner_product", }; - config.aset(Symbol::new("space"), RString::new(space_str))?; - + config.aset(ruby.to_symbol("space"), ruby.str_new(space_str))?; + let ef_search = self.ef_search.lock().unwrap(); - config.aset(Symbol::new("ef"), Integer::from_i64(*ef_search as i64))?; - config.aset(Symbol::new("size"), Integer::from_i64(self.size()? as i64))?; - + config.aset(ruby.to_symbol("ef"), ruby.integer_from_i64(*ef_search as i64))?; + config.aset(ruby.to_symbol("size"), ruby.integer_from_i64(self.size()? as i64))?; + Ok(config) } - - // Get statistics about the index + pub fn stats(&self) -> Result { - let stats = RHash::new(); - - stats.aset(Symbol::new("size"), Integer::from_i64(self.size()? as i64))?; - stats.aset(Symbol::new("dim"), Integer::from_i64(self.dim as i64))?; - + let ruby = Ruby::get().unwrap(); + let stats = ruby.hash_new(); + + stats.aset(ruby.to_symbol("size"), ruby.integer_from_i64(self.size()? as i64))?; + stats.aset(ruby.to_symbol("dim"), ruby.integer_from_i64(self.dim as i64))?; + let ef_search = self.ef_search.lock().unwrap(); - stats.aset(Symbol::new("ef_search"), Integer::from_i64(*ef_search as i64))?; - - // TODO: Add more statistics from HNSW structure - + stats.aset(ruby.to_symbol("ef_search"), ruby.integer_from_i64(*ef_search as i64))?; + Ok(stats) } - - // Load index from file (class method) + pub fn load(path: RString) -> Result { + let ruby = Ruby::get().unwrap(); let path_str = path.to_string()?; - - // Load metadata first to get dimensions and space + let metadata_path = format!("{}.metadata", path_str); let metadata_file = File::open(&metadata_path) - .map_err(|e| Error::new(exception::runtime_error(), format!("Failed to open metadata file: {}", e)))?; - + .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to open metadata file: {}", e)))?; + let ( _metadata_store, - _label_to_id, + _label_to_id, _current_id, _dim, _space_str, @@ -445,25 +422,19 @@ impl HnswIndex { HashMap, usize, usize, - String, // Changed from &str to String for deserialization + String, ) = bincode::deserialize_from(metadata_file) - .map_err(|e| Error::new(exception::runtime_error(), format!("Failed to load metadata: {}", e)))?; - - // Load HNSW structure + .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to load metadata: {}", e)))?; + let hnsw_dir = format!("{}_hnsw_data", path_str); let hnsw_path = std::path::Path::new(&hnsw_dir); - - // Create HnswIo and leak it to get 'static lifetime - // This is a memory leak, but necessary due to hnsw_rs lifetime constraints - // The memory will never be freed until the program exits + let hnswio = Box::new(HnswIo::new(hnsw_path, "hnsw")); let hnswio_static: &'static mut HnswIo = Box::leak(hnswio); - - // Now we can load the HNSW with 'static lifetime + let hnsw: Hnsw<'static, f32, DistL2> = hnswio_static.load_hnsw() - .map_err(|e| Error::new(exception::runtime_error(), format!("Failed to load HNSW index: {}", e)))?; - - // Use the loaded metadata + .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to load HNSW index: {}", e)))?; + let metadata_store = _metadata_store; let label_to_id = _label_to_id; let current_id = _current_id; @@ -472,12 +443,11 @@ impl HnswIndex { "euclidean" => DistanceType::Euclidean, "cosine" => DistanceType::Cosine, "inner_product" => DistanceType::InnerProduct, - _ => return Err(Error::new(exception::runtime_error(), "Unknown distance type in saved file")), + _ => return Err(Error::new(ruby.exception_runtime_error(), "Unknown distance type in saved file")), }; - - // Use default ef_construction as ef_search + let ef_search = 200; - + Ok(Self { hnsw: Arc::new(Mutex::new(hnsw)), dim, @@ -488,30 +458,27 @@ impl HnswIndex { ef_search: Arc::new(Mutex::new(ef_search)), }) } - - // Save index to file + pub fn save(&self, path: RString) -> Result { + let ruby = Ruby::get().unwrap(); let path_str = path.to_string()?; - - // Create directory for HNSW structure + let hnsw_dir = format!("{}_hnsw_data", path_str); std::fs::create_dir_all(&hnsw_dir) - .map_err(|e| Error::new(exception::runtime_error(), format!("Failed to create directory: {}", e)))?; - - // Save HNSW structure + .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to create directory: {}", e)))?; + { let hnsw = self.hnsw.lock().unwrap(); hnsw.file_dump(&std::path::Path::new(&hnsw_dir), "hnsw") - .map_err(|e| Error::new(exception::runtime_error(), format!("Failed to save HNSW: {}", e)))?; + .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to save HNSW: {}", e)))?; } - - // Save metadata + let metadata_path = format!("{}.metadata", path_str); { let metadata_store = self.metadata_store.lock().unwrap(); let label_to_id = self.label_to_id.lock().unwrap(); let current_id = self.current_id.lock().unwrap(); - + let metadata_data = ( &*metadata_store, &*label_to_id, @@ -523,56 +490,55 @@ impl HnswIndex { DistanceType::InnerProduct => "inner_product", }, ); - + let file = File::create(&metadata_path) - .map_err(|e| Error::new(exception::runtime_error(), format!("Failed to create metadata file: {}", e)))?; - + .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to create metadata file: {}", e)))?; + bincode::serialize_into(file, &metadata_data) - .map_err(|e| Error::new(exception::runtime_error(), format!("Failed to save metadata: {}", e)))?; + .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to save metadata: {}", e)))?; } - - Ok(value::qnil().as_value()) + + Ok(ruby.qnil().as_value()) } } // Helper function to parse a Ruby array into a Vec -fn parse_vector(array: RArray, expected_dim: usize) -> Result, Error> { +fn parse_vector(ruby: &Ruby, array: RArray, expected_dim: usize) -> Result, Error> { let len = array.len(); if len != expected_dim { return Err(Error::new( - exception::arg_error(), + ruby.exception_arg_error(), format!("Vector dimension mismatch: expected {}, got {}", expected_dim, len) )); } - + let mut vec = Vec::with_capacity(len); - for item in array.each() { - let value: f64 = TryConvert::try_convert(item?) - .map_err(|_| Error::new(exception::type_error(), "Vector elements must be numeric"))?; + for i in 0..len { + let value: f64 = array.entry(i as isize)?; vec.push(value as f32); } - + Ok(vec) } // Helper function to parse metadata -fn parse_metadata(value: Value) -> Result, Error> { +fn parse_metadata(ruby: &Ruby, value: Value) -> Result, Error> { let hash: RHash = TryConvert::try_convert(value) - .map_err(|_| Error::new(exception::type_error(), "Metadata must be a hash"))?; - + .map_err(|_| Error::new(ruby.exception_type_error(), "Metadata must be a hash"))?; + let mut metadata = HashMap::new(); - + hash.foreach(|key: Value, value: Value| { - // Handle both string and symbol keys + let ruby = Ruby::get().unwrap(); + let key_str = if let Ok(s) = String::try_convert(key) { s } else if let Ok(sym) = Symbol::try_convert(key) { sym.name()?.to_string() } else { - return Err(Error::new(exception::type_error(), "Metadata keys must be strings or symbols")); + return Err(Error::new(ruby.exception_type_error(), "Metadata keys must be strings or symbols")); }; - - // Convert value to string (handle various Ruby types) + let value_str = if let Ok(s) = String::try_convert(value) { s } else if let Ok(i) = Integer::try_convert(value) { @@ -580,22 +546,22 @@ fn parse_metadata(value: Value) -> Result, Error> { } else if let Ok(f) = Float::try_convert(value) { f.to_f64().to_string() } else { - // Fallback: use Ruby's to_s method let to_s_method = value.funcall::<_, _, RString>("to_s", ())?; to_s_method.to_string()? }; - + metadata.insert(key_str, value_str); Ok(ForEach::Continue) })?; - + Ok(metadata) } // Initialize the HNSW module pub fn init(parent: &magnus::RModule) -> Result<(), Error> { - let class = parent.define_class("HNSW", class::object())?; - + let ruby = Ruby::get().unwrap(); + let class = parent.define_class("HNSW", ruby.class_object())?; + class.define_singleton_method("new", function!(HnswIndex::new, 1))?; class.define_singleton_method("load", function!(HnswIndex::load, 1))?; class.define_method("add_item", method!(HnswIndex::add_item, 2))?; @@ -608,6 +574,6 @@ pub fn init(parent: &magnus::RModule) -> Result<(), Error> { class.define_method("config", method!(HnswIndex::config, 0))?; class.define_method("stats", method!(HnswIndex::stats, 0))?; class.define_method("save", method!(HnswIndex::save, 1))?; - + Ok(()) -} \ No newline at end of file +} diff --git a/ext/clusterkit/src/lib.rs b/ext/clusterkit/src/lib.rs index 75badab..4a8f4d8 100644 --- a/ext/clusterkit/src/lib.rs +++ b/ext/clusterkit/src/lib.rs @@ -1,4 +1,4 @@ -use magnus::{define_module, Error}; +use magnus::{Error, Ruby}; mod embedder; mod svd; @@ -10,15 +10,15 @@ mod hnsw; mod tests; #[magnus::init] -fn init() -> Result<(), Error> { - let module = define_module("ClusterKit")?; - +fn init(ruby: &Ruby) -> Result<(), Error> { + let module = ruby.define_module("ClusterKit")?; + // Initialize submodules embedder::init(&module)?; svd::init(&module)?; utils::init(&module)?; clustering::init(&module)?; hnsw::init(&module)?; - + Ok(()) } \ No newline at end of file diff --git a/ext/clusterkit/src/svd.rs b/ext/clusterkit/src/svd.rs index c581414..e5c7efd 100644 --- a/ext/clusterkit/src/svd.rs +++ b/ext/clusterkit/src/svd.rs @@ -1,91 +1,89 @@ -use magnus::{function, prelude::*, Error, Value, RArray}; +use magnus::{function, prelude::*, Error, Value, RArray, Ruby}; use annembed::tools::svdapprox::{SvdApprox, RangeApproxMode, RangeRank, MatRepr}; use crate::utils::ruby_array_to_ndarray2; pub fn init(parent: &magnus::RModule) -> Result<(), Error> { let svd_module = parent.define_module("SVD")?; - + svd_module.define_singleton_method( "randomized_svd_rust", function!(randomized_svd, 3), )?; - + Ok(()) } fn randomized_svd(matrix: Value, k: usize, n_iter: usize) -> Result { + let ruby = Ruby::get().unwrap(); + // Convert Ruby array to ndarray using shared helper let matrix_data = ruby_array_to_ndarray2(matrix)?; let (n_rows, n_cols) = matrix_data.dim(); - + if k > n_rows.min(n_cols) { return Err(Error::new( - magnus::exception::arg_error(), + ruby.exception_arg_error(), format!("k ({}) cannot be larger than min(rows, cols) = {}", k, n_rows.min(n_cols)), )); } - + // Create MatRepr for the full matrix let mat_repr = MatRepr::from_array2(matrix_data.clone()); - + // Create SvdApprox instance let mut svd_approx = SvdApprox::new(&mat_repr); - + // Set up parameters for randomized SVD - // Use RANK mode to specify the desired rank let params = RangeApproxMode::RANK(RangeRank::new(k, n_iter)); - + // Perform SVD let svd_result = svd_approx.direct_svd(params) - .map_err(|e| Error::new(magnus::exception::runtime_error(), e))?; - - // Extract U, S, V from the result - they are optional fields + .map_err(|e| Error::new(ruby.exception_runtime_error(), e))?; + + // Extract U, S, V from the result let u_matrix = svd_result.u.ok_or_else(|| { - Error::new(magnus::exception::runtime_error(), "No U matrix in SVD result") + Error::new(ruby.exception_runtime_error(), "No U matrix in SVD result") })?; - + let s_values = svd_result.s.ok_or_else(|| { - Error::new(magnus::exception::runtime_error(), "No S values in SVD result") + Error::new(ruby.exception_runtime_error(), "No S values in SVD result") })?; - + let vt_matrix = svd_result.vt.ok_or_else(|| { - Error::new(magnus::exception::runtime_error(), "No V^T matrix in SVD result") + Error::new(ruby.exception_runtime_error(), "No V^T matrix in SVD result") })?; - + // Convert results to Ruby arrays - // U matrix - convert ndarray to Ruby nested array - let u_ruby = RArray::new(); + let u_ruby = ruby.ary_new(); let u_shape = u_matrix.shape(); for i in 0..u_shape[0] { - let row = RArray::new(); + let row = ruby.ary_new(); for j in 0..u_shape[1] { row.push(u_matrix[[i, j]])?; } u_ruby.push(row)?; } - - // S values - convert to Ruby array - let s_ruby = RArray::new(); + + let s_ruby = ruby.ary_new(); for val in s_values.iter() { s_ruby.push(*val)?; } - - // V matrix (note: we have V^T, so we need to transpose) - let v_ruby = RArray::new(); + + let v_ruby = ruby.ary_new(); let vt_shape = vt_matrix.shape(); for i in 0..vt_shape[0] { - let row = RArray::new(); + let row = ruby.ary_new(); for j in 0..vt_shape[1] { row.push(vt_matrix[[i, j]])?; } v_ruby.push(row)?; } - + // Return [U, S, V^T] as a Ruby array - let result = RArray::new(); + let result = ruby.ary_new(); result.push(u_ruby)?; result.push(s_ruby)?; result.push(v_ruby)?; - + Ok(result) -} \ No newline at end of file +} diff --git a/ext/clusterkit/src/utils.rs b/ext/clusterkit/src/utils.rs index 57094c8..f1a04d2 100644 --- a/ext/clusterkit/src/utils.rs +++ b/ext/clusterkit/src/utils.rs @@ -1,34 +1,34 @@ -use magnus::{function, prelude::*, Error, Value, RArray, TryConvert, Float, Integer}; +use magnus::{function, prelude::*, Error, Value, RArray, TryConvert, Float, Integer, Ruby}; use ndarray::Array2; pub fn init(parent: &magnus::RModule) -> Result<(), Error> { let utils_module = parent.define_module("Utils")?; - + utils_module.define_singleton_method( "estimate_intrinsic_dimension_rust", function!(estimate_intrinsic_dimension, 2), )?; - + utils_module.define_singleton_method( "estimate_hubness_rust", function!(estimate_hubness, 1), )?; - + Ok(()) } fn estimate_intrinsic_dimension(_data: Value, _k_neighbors: usize) -> Result { - // TODO: Implement using annembed + let ruby = Ruby::get().unwrap(); Err(Error::new( - magnus::exception::not_imp_error(), + ruby.exception_not_imp_error(), "Dimension estimation not implemented yet", )) } fn estimate_hubness(_data: Value) -> Result { - // TODO: Implement using annembed + let ruby = Ruby::get().unwrap(); Err(Error::new( - magnus::exception::not_imp_error(), + ruby.exception_not_imp_error(), "Hubness estimation not implemented yet", )) } @@ -36,12 +36,13 @@ fn estimate_hubness(_data: Value) -> Result { /// Convert Ruby 2D array to ndarray Array2 /// Handles validation and provides consistent error messages pub fn ruby_array_to_ndarray2(data: Value) -> Result, Error> { + let ruby = Ruby::get().unwrap(); let rarray: RArray = TryConvert::try_convert(data)?; let n_samples = rarray.len(); if n_samples == 0 { return Err(Error::new( - magnus::exception::arg_error(), + ruby.exception_arg_error(), "Data cannot be empty", )); } @@ -52,7 +53,7 @@ pub fn ruby_array_to_ndarray2(data: Value) -> Result, Error> { if n_features == 0 { return Err(Error::new( - magnus::exception::arg_error(), + ruby.exception_arg_error(), "Data rows cannot be empty", )); } @@ -61,11 +62,11 @@ pub fn ruby_array_to_ndarray2(data: Value) -> Result, Error> { let mut data_array = Array2::::zeros((n_samples, n_features)); for i in 0..n_samples { let row: RArray = rarray.entry(i as isize)?; - + // Validate row length consistency if row.len() != n_features { return Err(Error::new( - magnus::exception::arg_error(), + ruby.exception_arg_error(), format!("Row {} has {} elements, expected {}", i, row.len(), n_features), )); } @@ -80,14 +81,15 @@ pub fn ruby_array_to_ndarray2(data: Value) -> Result, Error> { } /// Convert Ruby 2D array to Vec> -/// Handles validation and provides consistent error messages +/// Handles validation and provides consistent error messages pub fn ruby_array_to_vec_vec_f64(data: Value) -> Result>, Error> { + let ruby = Ruby::get().unwrap(); let rarray: RArray = TryConvert::try_convert(data)?; let n_samples = rarray.len(); if n_samples == 0 { return Err(Error::new( - magnus::exception::arg_error(), + ruby.exception_arg_error(), "Data cannot be empty", )); } @@ -98,13 +100,13 @@ pub fn ruby_array_to_vec_vec_f64(data: Value) -> Result>, Error> { for i in 0..n_samples { let row: RArray = rarray.entry(i as isize)?; let n_features = row.len(); - + // Check row length consistency match expected_features { Some(expected) => { if n_features != expected { return Err(Error::new( - magnus::exception::arg_error(), + ruby.exception_arg_error(), format!("Row {} has {} elements, expected {}", i, n_features, expected), )); } @@ -126,12 +128,13 @@ pub fn ruby_array_to_vec_vec_f64(data: Value) -> Result>, Error> { /// Convert Ruby 2D array to Vec> /// For algorithms that require f32 precision (like UMAP) pub fn ruby_array_to_vec_vec_f32(data: Value) -> Result>, Error> { + let ruby = Ruby::get().unwrap(); let rarray: RArray = TryConvert::try_convert(data)?; let array_len = rarray.len(); if array_len == 0 { return Err(Error::new( - magnus::exception::arg_error(), + ruby.exception_arg_error(), "Input data cannot be empty", )); } @@ -142,7 +145,7 @@ pub fn ruby_array_to_vec_vec_f32(data: Value) -> Result>, Error> { let row = rarray.entry::(i as isize)?; let row_array = RArray::try_convert(row).map_err(|_| { Error::new( - magnus::exception::type_error(), + ruby.exception_type_error(), "Expected array of arrays (2D array)", ) })?; @@ -158,7 +161,7 @@ pub fn ruby_array_to_vec_vec_f32(data: Value) -> Result>, Error> { i.to_i64()? as f32 } else { return Err(Error::new( - magnus::exception::type_error(), + ruby.exception_type_error(), "All values must be numeric", )); }; @@ -168,7 +171,7 @@ pub fn ruby_array_to_vec_vec_f32(data: Value) -> Result>, Error> { // Validate row length consistency if !rust_data.is_empty() && rust_row.len() != rust_data[0].len() { return Err(Error::new( - magnus::exception::arg_error(), + ruby.exception_arg_error(), "All rows must have the same length", )); } @@ -177,4 +180,4 @@ pub fn ruby_array_to_vec_vec_f32(data: Value) -> Result>, Error> { } Ok(rust_data) -} \ No newline at end of file +} diff --git a/lib/clusterkit.rb b/lib/clusterkit.rb index c1c5cfd..d574c5c 100644 --- a/lib/clusterkit.rb +++ b/lib/clusterkit.rb @@ -1,7 +1,7 @@ # frozen_string_literal: true require_relative "clusterkit/version" -require_relative "clusterkit/clusterkit" +require "clusterkit/clusterkit" require_relative "clusterkit/configuration" # Main module for ClusterKit gem diff --git a/spec/require_spec.rb b/spec/require_spec.rb new file mode 100644 index 0000000..158b649 --- /dev/null +++ b/spec/require_spec.rb @@ -0,0 +1,45 @@ +require "spec_helper" + +# These specs verify that the native extension loads correctly via $LOAD_PATH +# resolution (using `require`) rather than relative path resolution +# (using `require_relative`). +# +# Background: +# RubyGems installs native extensions into a separate extensions directory +# (e.g., ~/.gem/ruby/3.4.0/extensions/...) and adds that directory to +# $LOAD_PATH. Using `require_relative` bypasses $LOAD_PATH and looks only +# in the gem's lib/ directory, where the compiled .so/.bundle file does not +# exist. Using `require` resolves via $LOAD_PATH and finds the extension +# in the correct location. + +RSpec.describe "Native extension loading" do + it "loads the ClusterKit module successfully" do + expect(defined?(ClusterKit)).to eq("constant") + end + + it "makes ClusterKit::Error available" do + expect(defined?(ClusterKit::Error)).to eq("constant") + end + + it "makes ClusterKit::DimensionError available" do + expect(defined?(ClusterKit::DimensionError)).to eq("constant") + end + + it "makes ClusterKit::DataError available" do + expect(defined?(ClusterKit::DataError)).to eq("constant") + end + + it "can instantiate a UMAP object (proves native extension is functional)" do + umap = ClusterKit::Dimensionality::UMAP.new(n_components: 2) + expect(umap).to be_a(ClusterKit::Dimensionality::UMAP) + end + + it "loads clusterkit/clusterkit via require (not require_relative)" do + # Verify that lib/clusterkit.rb uses `require` for the native extension. + # This is critical because RubyGems places compiled extensions in the + # extensions directory, not in the gem's lib/ directory. + clusterkit_rb = File.read(File.expand_path("../lib/clusterkit.rb", __dir__)) + expect(clusterkit_rb).to include('require "clusterkit/clusterkit"') + expect(clusterkit_rb).not_to include('require_relative "clusterkit/clusterkit"') + end +end