diff --git a/data/chat_training_data.json b/data/chat_training_data.json index 4867459..17b9f83 100644 --- a/data/chat_training_data.json +++ b/data/chat_training_data.json @@ -51,5 +51,26 @@ "User: Thank you. Assistant: You're welcome! I'm happy to help ", "User: Can you explain that? Assistant: Of course! I'd be happy to explain in more detail ", "User: I don't understand. Assistant: No problem! Let me try to explain it differently ", - "User: What do you think? Assistant: Based on the information available, I think it's important to consider multiple perspectives " + "User: What do you think? Assistant: Based on the information available, I think it's important to consider multiple perspectives ", + "User: What causes earthquakes? Assistant: Earthquakes are caused by sudden movements of tectonic plates beneath Earth's surface ", + "User: How are clouds formed? Assistant: Clouds form when moist air rises and cools, causing water vapor to condense ", + "User: Why do leaves fall in autumn? Assistant: Leaves fall because trees prepare for winter by conserving energy and water ", + "User: How do magnets work? Assistant: Magnets work by generating a magnetic field that attracts certain materials like iron ", + "User: What are fossils? Assistant: Fossils are preserved remains or impressions of ancient organisms ", + "User: What is evaporation? Assistant: Evaporation is the process where liquid turns into vapor due to heat ", + "User: What is the ozone layer? Assistant: The ozone layer is a region of Earth's atmosphere that absorbs most of the sun's ultraviolet radiation ", + "User: Why do animals hibernate? Assistant: Animals hibernate to conserve energy during times of cold and food scarcity ", + "User: What is a black hole? Assistant: A black hole is a region in space where gravity is so strong that nothing can escape from it ", + "User: What is the water cycle? Assistant: The water cycle describes how water moves through evaporation, condensation, precipitation, and collection ", + "User: What is a constellation? Assistant: A constellation is a group of stars that form a recognizable pattern in the night sky ", + "User: What is renewable energy? Assistant: Renewable energy comes from sources that naturally replenish, such as wind, solar, and water ", + "User: How do ants live? Assistant: Ants live in colonies with structured roles, including workers, soldiers, and a queen ", + "User: What is a galaxy? Assistant: A galaxy is a massive system of stars, gas, dust, and dark matter held together by gravity ", + "User: Why do we see different phases of the moon? Assistant: The moon's phases result from its position relative to Earth and the sun ", + "User: Thanks for your help. Assistant: You're very welcome! Let me know if you have more questions ", + "User: That makes sense. Assistant: I'm glad to hear that! Is there anything else you'd like to explore? ", + "User: Interesting! Assistant: I agree—it's fascinating! Would you like to dive deeper into the topic? ", + "User: Could you give me an example? Assistant: Sure! Here's an example that might help clarify things ", + "User: That's confusing. Assistant: No worries! Let me try to explain it another way ", + "User: Do you know more about that? Assistant: Absolutely! Here's some more information that might help " ] diff --git a/data/pretraining_data.json b/data/pretraining_data.json index d5c8c93..c7e2aac 100644 --- a/data/pretraining_data.json +++ b/data/pretraining_data.json @@ -23,5 +23,25 @@ "Storms bring rain and strong winds ", "Seasons change throughout the year ", "Animals eat food to survive ", - "Plants need sunlight and water to grow " + "Plants need sunlight and water to grow ", + "Deserts are dry regions with little rainfall ", + "The Earth revolves around the sun once a year ", + "Tornadoes are fast-spinning columns of air ", + "Rainbows form when light refracts through water droplets ", + "Caves are natural underground spaces ", + "Bees collect nectar to make honey ", + "Lava flows from erupting volcanoes ", + "Thunder follows lightning during a storm ", + "Leaves change color in autumn ", + "The ocean is salty and covers most of Earth ", + "Spiders spin webs to catch insects ", + "Dolphins are intelligent marine mammals ", + "Butterflies emerge from cocoons ", + "Seeds grow into plants when given water and sunlight ", + "The sky appears blue due to scattered sunlight ", + "Airplanes fly by generating lift with their wings ", + "Tides are caused by the moon's gravity ", + "Bamboo grows quickly and is used in construction ", + "Crystals form through a natural solidification process ", + "Caterpillars turn into butterflies through metamorphosis " ] diff --git a/src/main.rs b/src/main.rs index 5babf3c..ffc5cec 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,54 +1,97 @@ use std::io::Write; -use ::llm::{EMBEDDING_DIM, HIDDEN_DIM, MAX_SEQ_LEN}; -use dataset_loader::{Dataset, DatasetType}; - -use crate::{ - embeddings::Embeddings, llm::LLM, output_projection::OutputProjection, - transformer::TransformerBlock, vocab::Vocab, +use llm::{ + EMBEDDING_DIM, HIDDEN_DIM, MAX_SEQ_LEN, Dataset, DatasetType, + Embeddings, LLM, OutputProjection, TransformerBlock, Vocab, }; -mod adam; -mod dataset_loader; -mod embeddings; -mod feed_forward; -mod layer_norm; -mod llm; -mod output_projection; -mod self_attention; -mod transformer; -mod vocab; +// We use the library crate (`llm`) defined in `src/lib.rs` instead of re-declaring +// modules inside the binary. The library re-exports Dataset/DatasetType and the +// configuration constants (MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM). fn main() { - // Mock input - test conversational format + // Mock input used to test the model before and after training let string = String::from("User: How do mountains form?"); + // Create a hash set to collect all unique tokens (words and punctuation) + let mut vocab_set = std::collections::HashSet::new(); + + // Add special end-of-sequence token to vocabulary + vocab_set.insert("".to_string()); + + // === Load dataset === let dataset = Dataset::new( String::from("data/pretraining_data.json"), String::from("data/chat_training_data.json"), DatasetType::JSON, - ); // Placeholder, not used in this example - - // Extract all unique words from training data to create vocabulary - let mut vocab_set = std::collections::HashSet::new(); - - // Process all training examples for vocabulary - // First process pre-training data - Vocab::process_text_for_vocab(&dataset.pretraining_data, &mut vocab_set); + ); // This loads pretraining and chat fine-tuning data + + // === Build vocabulary from pre-training data === + for text in &dataset.pretraining_data { + for word in text.split_whitespace() { + // Split punctuation from words (e.g., "hello," → "hello" and ",") + let mut current = String::new(); + for c in word.chars() { + if c.is_ascii_punctuation() { + if !current.is_empty() { + vocab_set.insert(current.clone()); + current.clear(); + } + vocab_set.insert(c.to_string()); + } else { + current.push(c); + } + } + if !current.is_empty() { + vocab_set.insert(current); + } + } + } - // Then process chat training data - Vocab::process_text_for_vocab(&dataset.chat_training_data, &mut vocab_set); + // === Build vocabulary from chat (instruction-tuning) data === + for row in &dataset.chat_training_data { + for word in row.split_whitespace() { + let mut current = String::new(); + for c in word.chars() { + if c.is_ascii_punctuation() { + if !current.is_empty() { + vocab_set.insert(current.clone()); + current.clear(); + } + vocab_set.insert(c.to_string()); + } else { + current.push(c); + } + } + if !current.is_empty() { + vocab_set.insert(current); + } + } + } + // Convert vocabulary set into a sorted vector for deterministic order let mut vocab_words: Vec = vocab_set.into_iter().collect(); - vocab_words.sort(); // Sort for deterministic ordering + vocab_words.sort(); + + // Convert Vec to Vec<&str> because Vocab expects string slices let vocab_words_refs: Vec<&str> = vocab_words.iter().map(|s: &String| s.as_str()).collect(); + + // Build the vocabulary structure let vocab = Vocab::new(vocab_words_refs); + // === Build Transformer-based model === + // These represent multiple stacked Transformer layers let transformer_block_1 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM); let transformer_block_2 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM); let transformer_block_3 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM); + + // Projection layer: maps hidden state to vocabulary logits let output_projection = OutputProjection::new(EMBEDDING_DIM, vocab.words.len()); + + // Embedding layer: turns tokens into dense vectors let embeddings = Embeddings::new(vocab.clone()); + + // Create the full LLM by stacking components in order let mut llm = LLM::new( vocab, vec![ @@ -60,19 +103,21 @@ fn main() { ], ); + // === Print model information === println!("\n=== MODEL INFORMATION ==="); println!("Network architecture: {}", llm.network_description()); println!( "Model configuration -> max_seq_len: {}, embedding_dim: {}, hidden_dim: {}", MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM ); - println!("Total parameters: {}", llm.total_parameters()); + // === Test before any training === println!("\n=== BEFORE TRAINING ==="); println!("Input: {}", string); println!("Output: {}", llm.predict(&string)); + // === Pre-training phase === println!("\n=== PRE-TRAINING MODEL ==="); println!( "Pre-training on {} examples for {} epochs with learning rate {}", @@ -81,20 +126,17 @@ fn main() { 0.0005 ); + // Collect pre-training examples as slices let pretraining_examples: Vec<&str> = dataset .pretraining_data .iter() .map(|s| s.as_str()) .collect(); - let chat_training_examples: Vec<&str> = dataset - .chat_training_data - .iter() - .map(|s| s.as_str()) - .collect(); - + // Train the model on pretraining data llm.train(pretraining_examples, 100, 0.0005); + // === Instruction tuning (fine-tuning for chat) === println!("\n=== INSTRUCTION TUNING ==="); println!( "Instruction tuning on {} examples for {} epochs with learning rate {}", @@ -103,41 +145,53 @@ fn main() { 0.0001 ); - llm.train(chat_training_examples, 100, 0.0001); // Much lower learning rate for stability + // Collect chat training examples as slices + let chat_training_examples: Vec<&str> = dataset + .chat_training_data + .iter() + .map(|s| s.as_str()) + .collect(); + + // Train again on chat data with a smaller learning rate for stability + llm.train(chat_training_examples, 100, 0.0001); + // === Test after training === println!("\n=== AFTER TRAINING ==="); println!("Input: {}", string); let result = llm.predict(&string); println!("Output: {}", result); println!("======================\n"); - // Interactive mode for user input + // === Interactive loop === println!("\n--- Interactive Mode ---"); println!("Type a prompt and press Enter to generate text."); println!("Type 'exit' to quit."); let mut input = String::new(); loop { - // Clear the input string + // Clear input buffer input.clear(); - // Prompt for user input + // Print prompt without newline print!("\nEnter prompt: "); std::io::stdout().flush().unwrap(); - // Read user input + // Read line from stdin std::io::stdin() .read_line(&mut input) .expect("Failed to read input"); - // Trim whitespace and check for exit command + // Remove surrounding whitespace let trimmed_input = input.trim(); + + // Check for exit command if trimmed_input.eq_ignore_ascii_case("exit") { println!("Exiting interactive mode."); break; } - // Generate prediction based on user input with "User:" prefix + // Add "User:" prefix so model understands the format + // Generate response from model let formatted_input = format!("User: {}", trimmed_input); let prediction = llm.predict(&formatted_input); println!("Model output: {}", prediction);