Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions catgrad-llm/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use catgrad::prelude::Dtype;
use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
use rayon::prelude::*;
use serde::de::DeserializeOwned;
use std::collections::{HashMap, HashSet};
use std::collections::{BTreeMap, HashSet};
use std::io::Read;
use std::path::{Path, PathBuf};
use tokenizers::tokenizer::Tokenizer;
Expand Down Expand Up @@ -535,8 +535,8 @@ pub fn load_model_weights<B: interpreter::Backend>(
}

// Read each tensor
let mut type_map = HashMap::new();
let mut data_map = HashMap::new();
let mut type_map = BTreeMap::new();
let mut data_map = BTreeMap::new();
let mut total_params = 0;

for file_path in model_paths {
Expand Down
8 changes: 4 additions & 4 deletions catgrad/examples/hidden.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use catgrad::prelude::ops::*;
use catgrad::prelude::*;

use std::collections::HashMap;
use std::collections::BTreeMap;

/// Construct, shapecheck, and interpret the `SimpleMNISTModel` using the ndarray backend.
fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand Down Expand Up @@ -167,7 +167,7 @@ fn load_param_types() -> typecheck::Parameters {
use catgrad::category::core::Dtype;
use catgrad::typecheck::value_types::{DtypeExpr, NatExpr, NdArrayType, ShapeExpr, TypeExpr};

let mut map = HashMap::new();
let mut map = BTreeMap::new();

// Layer 1: (28*28) → 100
let layer1_type = Type::Tensor(TypeExpr::NdArrayType(NdArrayType {
Expand Down Expand Up @@ -198,9 +198,9 @@ fn load_param_types() -> typecheck::Parameters {
// NOTE: you would normally create this data by reading the safetensors file!
fn load_param_data<B: interpreter::Backend>(backend: &B) -> interpreter::Parameters<B> {
use catgrad::category::core::Shape;
use std::collections::HashMap;
use std::collections::BTreeMap;

let mut map = HashMap::new();
let mut map = BTreeMap::new();

// Layer 1 weights: [784, 100] - initialize with small random-ish values
let layer1_data: Vec<f32> = (0..784 * 100)
Expand Down
17 changes: 7 additions & 10 deletions catgrad/src/abstract_interpreter/parameters.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
//use super::backend::Backend;
//use super::types::TaggedNdArray;

use super::{Interpreter, Value};
use crate::path::Path;
use std::collections::HashMap;
use std::collections::btree_map::{BTreeMap, Keys};

#[derive(Clone, Debug)]
pub struct Parameters<I: Interpreter>(pub HashMap<Path, Value<I>>);
pub struct Parameters<I: Interpreter>(pub BTreeMap<Path, Value<I>>);

// Needed so Backend doesn't have to implement Default
impl<I: Interpreter> Default for Parameters<I> {
Expand All @@ -15,29 +12,29 @@ impl<I: Interpreter> Default for Parameters<I> {
}
}

impl<I: Interpreter> From<HashMap<Path, Value<I>>> for Parameters<I> {
fn from(map: HashMap<Path, Value<I>>) -> Self {
impl<I: Interpreter> From<BTreeMap<Path, Value<I>>> for Parameters<I> {
fn from(map: BTreeMap<Path, Value<I>>) -> Self {
Parameters(map)
}
}

impl<const N: usize, I: Interpreter> From<[(Path, Value<I>); N]> for Parameters<I> {
fn from(arr: [(Path, Value<I>); N]) -> Self {
Parameters(HashMap::from(arr))
Parameters(BTreeMap::from(arr))
}
}

impl<'a, I: Interpreter> IntoIterator for &'a Parameters<I> {
type Item = &'a Path;
type IntoIter = std::collections::hash_map::Keys<'a, Path, Value<I>>;
type IntoIter = Keys<'a, Path, Value<I>>;

fn into_iter(self) -> Self::IntoIter {
self.0.keys()
}
}

impl<I: Interpreter> Parameters<I> {
pub fn keys(&self) -> std::collections::hash_map::Keys<'_, Path, Value<I>> {
pub fn keys(&self) -> Keys<'_, Path, Value<I>> {
self.0.keys()
}
}
4 changes: 2 additions & 2 deletions catgrad/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ use std::fmt;
use std::slice;

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Path(Vec<PathComponent>);

// Names of definitions
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct PathComponent(String); // only [a-zA-Z_]

pub fn path(components: Vec<&str>) -> Result<Path, InvalidPathComponent> {
Expand Down