diff --git a/catgrad-llm/src/utils/mod.rs b/catgrad-llm/src/utils/mod.rs index 68dfe33b..6b3917a0 100644 --- a/catgrad-llm/src/utils/mod.rs +++ b/catgrad-llm/src/utils/mod.rs @@ -2,7 +2,7 @@ use catgrad::prelude::Dtype; use hf_hub::{Repo, RepoType, api::sync::ApiBuilder}; use rayon::prelude::*; use serde::de::DeserializeOwned; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeMap, HashSet}; use std::io::Read; use std::path::{Path, PathBuf}; use tokenizers::tokenizer::Tokenizer; @@ -535,8 +535,8 @@ pub fn load_model_weights( } // Read each tensor - let mut type_map = HashMap::new(); - let mut data_map = HashMap::new(); + let mut type_map = BTreeMap::new(); + let mut data_map = BTreeMap::new(); let mut total_params = 0; for file_path in model_paths { diff --git a/catgrad/examples/hidden.rs b/catgrad/examples/hidden.rs index f7ed3f11..80da9c1d 100644 --- a/catgrad/examples/hidden.rs +++ b/catgrad/examples/hidden.rs @@ -1,7 +1,7 @@ use catgrad::prelude::ops::*; use catgrad::prelude::*; -use std::collections::HashMap; +use std::collections::BTreeMap; /// Construct, shapecheck, and interpret the `SimpleMNISTModel` using the ndarray backend. fn main() -> Result<(), Box> { @@ -167,7 +167,7 @@ fn load_param_types() -> typecheck::Parameters { use catgrad::category::core::Dtype; use catgrad::typecheck::value_types::{DtypeExpr, NatExpr, NdArrayType, ShapeExpr, TypeExpr}; - let mut map = HashMap::new(); + let mut map = BTreeMap::new(); // Layer 1: (28*28) → 100 let layer1_type = Type::Tensor(TypeExpr::NdArrayType(NdArrayType { @@ -198,9 +198,9 @@ fn load_param_types() -> typecheck::Parameters { // NOTE: you would normally create this data by reading the safetensors file! fn load_param_data(backend: &B) -> interpreter::Parameters { use catgrad::category::core::Shape; - use std::collections::HashMap; + use std::collections::BTreeMap; - let mut map = HashMap::new(); + let mut map = BTreeMap::new(); // Layer 1 weights: [784, 100] - initialize with small random-ish values let layer1_data: Vec = (0..784 * 100) diff --git a/catgrad/src/abstract_interpreter/parameters.rs b/catgrad/src/abstract_interpreter/parameters.rs index 8ad66633..85af0f0d 100644 --- a/catgrad/src/abstract_interpreter/parameters.rs +++ b/catgrad/src/abstract_interpreter/parameters.rs @@ -1,12 +1,9 @@ -//use super::backend::Backend; -//use super::types::TaggedNdArray; - use super::{Interpreter, Value}; use crate::path::Path; -use std::collections::HashMap; +use std::collections::btree_map::{BTreeMap, Keys}; #[derive(Clone, Debug)] -pub struct Parameters(pub HashMap>); +pub struct Parameters(pub BTreeMap>); // Needed so Backend doesn't have to implement Default impl Default for Parameters { @@ -15,21 +12,21 @@ impl Default for Parameters { } } -impl From>> for Parameters { - fn from(map: HashMap>) -> Self { +impl From>> for Parameters { + fn from(map: BTreeMap>) -> Self { Parameters(map) } } impl From<[(Path, Value); N]> for Parameters { fn from(arr: [(Path, Value); N]) -> Self { - Parameters(HashMap::from(arr)) + Parameters(BTreeMap::from(arr)) } } impl<'a, I: Interpreter> IntoIterator for &'a Parameters { type Item = &'a Path; - type IntoIter = std::collections::hash_map::Keys<'a, Path, Value>; + type IntoIter = Keys<'a, Path, Value>; fn into_iter(self) -> Self::IntoIter { self.0.keys() @@ -37,7 +34,7 @@ impl<'a, I: Interpreter> IntoIterator for &'a Parameters { } impl Parameters { - pub fn keys(&self) -> std::collections::hash_map::Keys<'_, Path, Value> { + pub fn keys(&self) -> Keys<'_, Path, Value> { self.0.keys() } } diff --git a/catgrad/src/path.rs b/catgrad/src/path.rs index 8beeaacf..9fdbd2ba 100644 --- a/catgrad/src/path.rs +++ b/catgrad/src/path.rs @@ -3,12 +3,12 @@ use std::fmt; use std::slice; #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct Path(Vec); // Names of definitions #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct PathComponent(String); // only [a-zA-Z_] pub fn path(components: Vec<&str>) -> Result {