23 changes: 22 additions & 1 deletion data/chat_training_data.json
@@ -51,5 +51,26 @@
"User: Thank you. Assistant: You're welcome! I'm happy to help </s>",
"User: Can you explain that? Assistant: Of course! I'd be happy to explain in more detail </s>",
"User: I don't understand. Assistant: No problem! Let me try to explain it differently </s>",
"User: What do you think? Assistant: Based on the information available, I think it's important to consider multiple perspectives </s>"
"User: What do you think? Assistant: Based on the information available, I think it's important to consider multiple perspectives </s>",
"User: What causes earthquakes? Assistant: Earthquakes are caused by sudden movements of tectonic plates beneath Earth's surface </s>",
"User: How are clouds formed? Assistant: Clouds form when moist air rises and cools, causing water vapor to condense </s>",
"User: Why do leaves fall in autumn? Assistant: Leaves fall because trees prepare for winter by conserving energy and water </s>",
"User: How do magnets work? Assistant: Magnets work by generating a magnetic field that attracts certain materials like iron </s>",
"User: What are fossils? Assistant: Fossils are preserved remains or impressions of ancient organisms </s>",
"User: What is evaporation? Assistant: Evaporation is the process where liquid turns into vapor due to heat </s>",
"User: What is the ozone layer? Assistant: The ozone layer is a region of Earth's atmosphere that absorbs most of the sun's ultraviolet radiation </s>",
"User: Why do animals hibernate? Assistant: Animals hibernate to conserve energy during times of cold and food scarcity </s>",
"User: What is a black hole? Assistant: A black hole is a region in space where gravity is so strong that nothing can escape from it </s>",
"User: What is the water cycle? Assistant: The water cycle describes how water moves through evaporation, condensation, precipitation, and collection </s>",
"User: What is a constellation? Assistant: A constellation is a group of stars that form a recognizable pattern in the night sky </s>",
"User: What is renewable energy? Assistant: Renewable energy comes from sources that naturally replenish, such as wind, solar, and water </s>",
"User: How do ants live? Assistant: Ants live in colonies with structured roles, including workers, soldiers, and a queen </s>",
"User: What is a galaxy? Assistant: A galaxy is a massive system of stars, gas, dust, and dark matter held together by gravity </s>",
"User: Why do we see different phases of the moon? Assistant: The moon's phases result from its position relative to Earth and the sun </s>",
"User: Thanks for your help. Assistant: You're very welcome! Let me know if you have more questions </s>",
"User: That makes sense. Assistant: I'm glad to hear that! Is there anything else you'd like to explore? </s>",
"User: Interesting! Assistant: I agree—it's fascinating! Would you like to dive deeper into the topic? </s>",
"User: Could you give me an example? Assistant: Sure! Here's an example that might help clarify things </s>",
"User: That's confusing. Assistant: No worries! Let me try to explain it another way </s>",
"User: Do you know more about that? Assistant: Absolutely! Here's some more information that might help </s>"
]
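Every row added to `chat_training_data.json` follows the pattern `User: <prompt> Assistant: <reply> </s>`. As a minimal sketch of how one such row could be split back into its two turns (the `split_turns` helper is hypothetical and not part of this PR):

```rust
// Split one "User: ... Assistant: ... </s>" training row into its two turns.
// Hypothetical helper for illustration; the project itself trains on the raw
// strings and does not necessarily parse them this way.
fn split_turns(example: &str) -> Option<(String, String)> {
    // Drop the trailing end-of-sequence marker and surrounding whitespace.
    let body = example.trim_end().trim_end_matches("</s>").trim_end();
    // Peel off the "User:" prefix, then cut at the "Assistant:" marker.
    let rest = body.strip_prefix("User:")?;
    let (user, assistant) = rest.split_once("Assistant:")?;
    Some((user.trim().to_string(), assistant.trim().to_string()))
}

fn main() {
    let row = "User: What are fossils? Assistant: Fossils are preserved \
               remains or impressions of ancient organisms </s>";
    let (user, assistant) = split_turns(row).expect("well-formed row");
    assert_eq!(user, "What are fossils?");
    assert!(assistant.starts_with("Fossils are preserved"));
    println!("user = {user:?}");
}
```

Rows that do not contain both markers yield `None`, so malformed data is easy to filter out.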
22 changes: 21 additions & 1 deletion data/pretraining_data.json
@@ -23,5 +23,25 @@
"Storms bring rain and strong winds </s>",
"Seasons change throughout the year </s>",
"Animals eat food to survive </s>",
"Plants need sunlight and water to grow </s>"
"Plants need sunlight and water to grow </s>",
"Deserts are dry regions with little rainfall </s>",
"The Earth revolves around the sun once a year </s>",
"Tornadoes are fast-spinning columns of air </s>",
"Rainbows form when light refracts through water droplets </s>",
"Caves are natural underground spaces </s>",
"Bees collect nectar to make honey </s>",
"Lava flows from erupting volcanoes </s>",
"Thunder follows lightning during a storm </s>",
"Leaves change color in autumn </s>",
"The ocean is salty and covers most of Earth </s>",
"Spiders spin webs to catch insects </s>",
"Dolphins are intelligent marine mammals </s>",
"Butterflies emerge from chrysalises </s>",
"Seeds grow into plants when given water and sunlight </s>",
"The sky appears blue due to scattered sunlight </s>",
"Airplanes fly by generating lift with their wings </s>",
"Tides are caused by the moon's gravity </s>",
"Bamboo grows quickly and is used in construction </s>",
"Crystals form through a natural solidification process </s>",
"Caterpillars turn into butterflies through metamorphosis </s>"
]
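Both data files are plain JSON arrays of strings, each terminated by the `</s>` token. The PR relies on a `dataset_loader` module (not shown in this diff) to read them; as a simplified, dependency-free stand-in, a loader for this restricted format might look like:

```rust
// Minimal loader for a JSON array of plain strings, such as the files in
// data/. Simplified sketch only: it handles quoted strings with basic
// escapes, not the full JSON grammar the real dataset_loader may support.
fn load_string_array(json: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut current = String::new();
    let mut in_string = false;
    let mut escaped = false;
    for c in json.chars() {
        if in_string {
            if escaped {
                match c {
                    'n' => current.push('\n'),
                    't' => current.push('\t'),
                    other => current.push(other), // covers \" and \\
                }
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
                out.push(std::mem::take(&mut current));
            } else {
                current.push(c);
            }
        } else if c == '"' {
            in_string = true;
        }
    }
    out
}

fn main() {
    let json = r#"[
  "Plants need sunlight and water to grow </s>",
  "Deserts are dry regions with little rainfall </s>"
]"#;
    let rows = load_string_array(json);
    assert_eq!(rows.len(), 2);
    assert!(rows.iter().all(|r| r.ends_with("</s>")));
}
```

In a real project a JSON library (e.g. `serde_json`) would be the safer choice; this sketch only illustrates the shape of the data.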
138 changes: 96 additions & 42 deletions src/main.rs
@@ -1,54 +1,97 @@
use std::io::Write;

use ::llm::{EMBEDDING_DIM, HIDDEN_DIM, MAX_SEQ_LEN};
use dataset_loader::{Dataset, DatasetType};

use crate::{
embeddings::Embeddings, llm::LLM, output_projection::OutputProjection,
transformer::TransformerBlock, vocab::Vocab,
use llm::{
EMBEDDING_DIM, HIDDEN_DIM, MAX_SEQ_LEN, Dataset, DatasetType,
Embeddings, LLM, OutputProjection, TransformerBlock, Vocab,
};

mod adam;
mod dataset_loader;
mod embeddings;
mod feed_forward;
mod layer_norm;
mod llm;
mod output_projection;
mod self_attention;
mod transformer;
mod vocab;
// We use the library crate (`llm`) defined in `src/lib.rs` instead of re-declaring
// modules inside the binary. The library re-exports Dataset/DatasetType and the
// configuration constants (MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM).

fn main() {
// Mock input - test conversational format
// Mock input used to test the model before and after training
let string = String::from("User: How do mountains form?");

// Create a hash set to collect all unique tokens (words and punctuation)
let mut vocab_set = std::collections::HashSet::new();

// Add special end-of-sequence token to vocabulary
vocab_set.insert("</s>".to_string());

// === Load dataset ===
let dataset = Dataset::new(
String::from("data/pretraining_data.json"),
String::from("data/chat_training_data.json"),
DatasetType::JSON,
); // Placeholder, not used in this example

// Extract all unique words from training data to create vocabulary
let mut vocab_set = std::collections::HashSet::new();

// Process all training examples for vocabulary
// First process pre-training data
Vocab::process_text_for_vocab(&dataset.pretraining_data, &mut vocab_set);
); // This loads pretraining and chat fine-tuning data

// === Build vocabulary from pre-training data ===
for text in &dataset.pretraining_data {
for word in text.split_whitespace() {
// Split punctuation from words (e.g., "hello," → "hello" and ",")
let mut current = String::new();
for c in word.chars() {
if c.is_ascii_punctuation() {
if !current.is_empty() {
vocab_set.insert(current.clone());
current.clear();
}
vocab_set.insert(c.to_string());
} else {
current.push(c);
}
}
if !current.is_empty() {
vocab_set.insert(current);
}
}
}

// Then process chat training data
Vocab::process_text_for_vocab(&dataset.chat_training_data, &mut vocab_set);
// === Build vocabulary from chat (instruction-tuning) data ===
for row in &dataset.chat_training_data {
for word in row.split_whitespace() {
let mut current = String::new();
for c in word.chars() {
if c.is_ascii_punctuation() {
if !current.is_empty() {
vocab_set.insert(current.clone());
current.clear();
}
vocab_set.insert(c.to_string());
} else {
current.push(c);
}
}
if !current.is_empty() {
vocab_set.insert(current);
}
}
}
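The inner loops above split punctuation off each whitespace-separated word before adding tokens to the vocabulary. That pass can be exercised as a standalone sketch (the `split_word` name is mine, not from the PR):

```rust
// Standalone version of the punctuation-splitting step used when building
// the vocabulary: "hello," becomes ["hello", ","].
fn split_word(word: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    for c in word.chars() {
        if c.is_ascii_punctuation() {
            // Flush any letters accumulated so far, then emit the
            // punctuation character as its own token.
            if !current.is_empty() {
                tokens.push(std::mem::take(&mut current));
            }
            tokens.push(c.to_string());
        } else {
            current.push(c);
        }
    }
    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}

fn main() {
    assert_eq!(split_word("hello,"), vec!["hello", ","]);
    // Note: this rule shatters the end-of-sequence marker, since '<', '/'
    // and '>' are ASCII punctuation.
    assert_eq!(split_word("</s>"), vec!["<", "/", "s", ">"]);
}
```

The second assertion shows why the code inserts `"</s>"` into `vocab_set` explicitly before scanning the data: the splitting pass alone would never produce it as a single token.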

// Convert vocabulary set into a sorted vector for deterministic order
let mut vocab_words: Vec<String> = vocab_set.into_iter().collect();
vocab_words.sort(); // Sort for deterministic ordering
vocab_words.sort();

// Convert Vec<String> to Vec<&str> because Vocab expects string slices
let vocab_words_refs: Vec<&str> = vocab_words.iter().map(|s: &String| s.as_str()).collect();

// Build the vocabulary structure
let vocab = Vocab::new(vocab_words_refs);

// === Build Transformer-based model ===
// These represent multiple stacked Transformer layers
let transformer_block_1 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
let transformer_block_2 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
let transformer_block_3 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);

// Projection layer: maps hidden state to vocabulary logits
let output_projection = OutputProjection::new(EMBEDDING_DIM, vocab.words.len());

// Embedding layer: turns tokens into dense vectors
let embeddings = Embeddings::new(vocab.clone());

// Create the full LLM by stacking components in order
let mut llm = LLM::new(
vocab,
vec![
@@ -60,19 +103,21 @@ fn main() {
],
);

// === Print model information ===
println!("\n=== MODEL INFORMATION ===");
println!("Network architecture: {}", llm.network_description());
println!(
"Model configuration -> max_seq_len: {}, embedding_dim: {}, hidden_dim: {}",
MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM
);

println!("Total parameters: {}", llm.total_parameters());

// === Test before any training ===
println!("\n=== BEFORE TRAINING ===");
println!("Input: {}", string);
println!("Output: {}", llm.predict(&string));

// === Pre-training phase ===
println!("\n=== PRE-TRAINING MODEL ===");
println!(
"Pre-training on {} examples for {} epochs with learning rate {}",
@@ -81,20 +126,17 @@
0.0005
);

// Collect pre-training examples as slices
let pretraining_examples: Vec<&str> = dataset
.pretraining_data
.iter()
.map(|s| s.as_str())
.collect();

let chat_training_examples: Vec<&str> = dataset
.chat_training_data
.iter()
.map(|s| s.as_str())
.collect();

// Train the model on pretraining data
llm.train(pretraining_examples, 100, 0.0005);

// === Instruction tuning (fine-tuning for chat) ===
println!("\n=== INSTRUCTION TUNING ===");
println!(
"Instruction tuning on {} examples for {} epochs with learning rate {}",
@@ -103,41 +145,53 @@
0.0001
);

llm.train(chat_training_examples, 100, 0.0001); // Much lower learning rate for stability
// Collect chat training examples as slices
let chat_training_examples: Vec<&str> = dataset
.chat_training_data
.iter()
.map(|s| s.as_str())
.collect();

// Train again on chat data with a smaller learning rate for stability
llm.train(chat_training_examples, 100, 0.0001);

// === Test after training ===
println!("\n=== AFTER TRAINING ===");
println!("Input: {}", string);
let result = llm.predict(&string);
println!("Output: {}", result);
println!("======================\n");

// Interactive mode for user input
// === Interactive loop ===
println!("\n--- Interactive Mode ---");
println!("Type a prompt and press Enter to generate text.");
println!("Type 'exit' to quit.");

let mut input = String::new();
loop {
// Clear the input string
// Clear input buffer
input.clear();

// Prompt for user input
// Print prompt without newline
print!("\nEnter prompt: ");
std::io::stdout().flush().unwrap();

// Read user input
// Read line from stdin
std::io::stdin()
.read_line(&mut input)
.expect("Failed to read input");

// Trim whitespace and check for exit command
// Remove surrounding whitespace
let trimmed_input = input.trim();

// Check for exit command
if trimmed_input.eq_ignore_ascii_case("exit") {
println!("Exiting interactive mode.");
break;
}

// Generate prediction based on user input with "User:" prefix
// Add "User:" prefix so model understands the format
// Generate response from model
let formatted_input = format!("User: {}", trimmed_input);
let prediction = llm.predict(&formatted_input);
println!("Model output: {}", prediction);