diff --git a/src/llm.rs b/src/llm.rs
index 89be934..089d3a2 100644
--- a/src/llm.rs
+++ b/src/llm.rs
@@ -23,7 +23,8 @@ pub struct LLM {
 impl Default for LLM {
     fn default() -> Self {
-        let transformer_block = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
+        let num_heads = 8; // Default to 8 attention heads
+        let transformer_block = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads);
         let output_projection = OutputProjection::new(EMBEDDING_DIM, Vocab::default_words().len());
         Self {
             vocab: Vocab::default(),
@@ -128,7 +129,7 @@ impl LLM {
     pub fn train(&mut self, data: Vec<&str>, epochs: usize, lr: f32) {
         let tokenized_data = data
             .iter()
-            .map(|input| (self.tokenize(input)))
+            .map(|input| self.tokenize(input))
             .collect::<Vec<Vec<usize>>>();
 
         for epoch in 0..epochs {
diff --git a/src/main.rs b/src/main.rs
index b08e920..dec10a4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -173,9 +173,11 @@ fn main() {
     let vocab_words_refs: Vec<&str> = vocab_words.iter().map(|s| s.as_str()).collect();
     let vocab = Vocab::new(vocab_words_refs);
 
-    let transformer_block_1 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
-    let transformer_block_2 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
-    let transformer_block_3 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
+    // Using 8 attention heads (EMBEDDING_DIM=128 / 8 = 16 dim per head)
+    let num_heads = 8;
+    let transformer_block_1 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads);
+    let transformer_block_2 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads);
+    let transformer_block_3 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads);
     let output_projection = OutputProjection::new(EMBEDDING_DIM, vocab.words.len());
     let embeddings = Embeddings::new(vocab.clone());
     let mut llm = LLM::new(vocab, vec![
@@ -188,6 +190,11 @@ fn main() {
 
     println!("\n=== MODEL INFORMATION ===");
     println!("Network architecture: {}", llm.network_description());
+    println!(
+        "Model configuration -> max_seq_len: {}, embedding_dim: {}, hidden_dim: {}, num_heads: {}",
+        MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM, num_heads
+    );
+
     println!("\n=== BEFORE TRAINING ===");
     println!("Input: {}", string);
diff --git a/src/self_attention.rs b/src/self_attention.rs
index a485176..6c4f3c8 100644
--- a/src/self_attention.rs
+++ b/src/self_attention.rs
@@ -7,27 +7,42 @@ use std::f32;
 pub struct SelfAttention {
     pub embedding_dim: usize,
+    pub num_heads: usize,
+    pub head_dim: usize,
     w_q: Array2<f32>, // Weight matrices for Q, K, V
     w_k: Array2<f32>,
     w_v: Array2<f32>,
+    w_o: Array2<f32>, // Output projection matrix
     cached_input: Option<Array2<f32>>,
+    cached_q: Option<Array2<f32>>,
+    cached_k: Option<Array2<f32>>,
+    cached_v: Option<Array2<f32>>,
+    cached_attn_weights: Option<Vec<Array2<f32>>>,
     optimizer_w_q: Adam,
     optimizer_w_k: Adam,
     optimizer_w_v: Adam,
+    optimizer_w_o: Adam,
 }
 
 impl Default for SelfAttention {
     fn default() -> Self {
-        SelfAttention::new(EMBEDDING_DIM)
+        SelfAttention::new(EMBEDDING_DIM, 8) // 8 attention heads by default
     }
 }
 
 impl SelfAttention {
-    /// Initializes a Transformer with random Q, K, V weights
-    pub fn new(embedding_dim: usize) -> Self {
+    /// Initializes a Multi-Head Attention with random Q, K, V, O weights
+    /// num_heads: Number of attention heads (embedding_dim must be divisible by num_heads)
+    pub fn new(embedding_dim: usize, num_heads: usize) -> Self {
+        assert!(
+            embedding_dim % num_heads == 0,
+            "embedding_dim must be divisible by num_heads"
+        );
+
+        let head_dim = embedding_dim / num_heads;
         let mut rng = rand::rng();
         // Xavier/He initialization: std = sqrt(2 / fan_in)
         let std = (2.0 / embedding_dim as f32).sqrt();
@@ -35,13 +50,21 @@ impl SelfAttention {
 
         SelfAttention {
             embedding_dim,
+            num_heads,
+            head_dim,
             w_q: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)),
             w_k: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)),
             w_v: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)),
+            w_o: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)),
             cached_input: None,
+            cached_q: None,
+            cached_k: None,
+            cached_v: None,
+            cached_attn_weights: None,
             optimizer_w_q: Adam::new((embedding_dim, embedding_dim)),
             optimizer_w_k: Adam::new((embedding_dim, embedding_dim)),
             optimizer_w_v: Adam::new((embedding_dim, embedding_dim)),
+            optimizer_w_o: Adam::new((embedding_dim, embedding_dim)),
         }
     }
@@ -50,10 +73,43 @@ impl SelfAttention {
         let k = input.dot(&self.w_k); // K = X * W_K
         let v = input.dot(&self.w_v); // V = X * W_V
         (q, k, v)
+
+    }
+
+    /// Split the input tensor into multiple heads
+    /// Input shape: [seq_len, embedding_dim]
+    /// Output shape: [num_heads, seq_len, head_dim]
+    fn split_heads(&self, x: &Array2<f32>) -> Vec<Array2<f32>> {
+        let mut heads = Vec::with_capacity(self.num_heads);
+
+        for h in 0..self.num_heads {
+            let start = h * self.head_dim;
+            let end = start + self.head_dim;
+            let head = x.slice(ndarray::s![.., start..end]).to_owned();
+            heads.push(head);
+        }
+
+        heads
+    }
+
+    /// Concatenate multiple heads back together
+    /// Input: Vec of [seq_len, head_dim] arrays
+    /// Output shape: [seq_len, embedding_dim]
+    fn concat_heads(&self, heads: Vec<Array2<f32>>) -> Array2<f32> {
+        let seq_len = heads[0].shape()[0];
+        let mut result = Array2::zeros((seq_len, self.embedding_dim));
+
+        for (h, head) in heads.iter().enumerate() {
+            let start = h * self.head_dim;
+            let end = start + self.head_dim;
+            result.slice_mut(ndarray::s![.., start..end]).assign(head);
+        }
+
+        result
     }
 
-    fn attention(&self, q: &Array2<f32>, k: &Array2<f32>, v: &Array2<f32>) -> Array2<f32> {
-        let dk = (self.embedding_dim as f32).sqrt();
+    /// Single-head attention computation
+    fn attention_head(&self, q: &Array2<f32>, k: &Array2<f32>, v: &Array2<f32>) -> (Array2<f32>, Array2<f32>) {
+        let dk = (self.head_dim as f32).sqrt();
         let k_t = k.t();
         let mut scores = q.dot(&k_t) / dk;
@@ -67,7 +123,36 @@ impl SelfAttention {
         }
 
         let weights = self.softmax(&scores);
-        weights.dot(v)
+        let output = weights.dot(v);
+        (output, weights)
+    }
+
+    /// Multi-head attention: applies attention independently for each head
+    fn multi_head_attention(
+        &self,
+        q: &Array2<f32>,
+        k: &Array2<f32>,
+        v: &Array2<f32>,
+    ) -> (Array2<f32>, Vec<Array2<f32>>) {
+        // Split into heads
+        let q_heads = self.split_heads(q);
+        let k_heads = self.split_heads(k);
+        let v_heads = self.split_heads(v);
+
+        // Apply attention for each head
+        let mut output_heads = Vec::with_capacity(self.num_heads);
+        let mut attn_weights = Vec::with_capacity(self.num_heads);
+
+        for i in 0..self.num_heads {
+            let (head_output, head_weights) = self.attention_head(&q_heads[i], &k_heads[i], &v_heads[i]);
+            output_heads.push(head_output);
+            attn_weights.push(head_weights);
+        }
+
+        // Concatenate heads
+        let concat_output = self.concat_heads(output_heads);
+
+        (concat_output, attn_weights)
     }
 
     fn softmax(&self, scores: &Array2<f32>) -> Array2<f32> {
@@ -123,67 +208,119 @@ impl SelfAttention {
 impl Layer for SelfAttention {
     fn layer_type(&self) -> &str {
-        "SelfAttention"
+        "MultiHeadSelfAttention"
     }
 
     fn forward(&mut self, input: &Array2<f32>) -> Array2<f32> {
+        // Compute Q, K, V projections
+        let (q, k, v) = self.compute_qkv(input);
+
+        // Cache for backward pass
         self.cached_input = Some(input.clone());
-        let qkv = self.compute_qkv(input);
-        let attention = self.attention(&qkv.0, &qkv.1, &qkv.2);
-        attention + input // residual connection (no LayerNorm here)
+        self.cached_q = Some(q.clone());
+        self.cached_k = Some(k.clone());
+        self.cached_v = Some(v.clone());
+
+        // Apply multi-head attention
+        let (multi_head_output, attn_weights) = self.multi_head_attention(&q, &k, &v);
+        self.cached_attn_weights = Some(attn_weights);
+
+        // Apply output projection
+        let projected = multi_head_output.dot(&self.w_o);
+
+        // Add residual connection
+        projected + input
     }
 
     fn backward(&mut self, grads: &Array2<f32>, lr: f32) -> Array2<f32> {
         let input = self.cached_input.as_ref().unwrap();
-        let q = input.dot(&self.w_q);
-        let k = input.dot(&self.w_k);
-        let v = input.dot(&self.w_v);
-        let dk = self.w_q.shape()[1] as f32;
-        let scale = dk.sqrt();
-
-        let mut scores = q.dot(&k.t()) / scale;
-
-        // Apply causal masking - prevent attention to future tokens
-        let seq_len = scores.shape()[0];
-        for i in 0..seq_len {
-            for j in (i + 1)..seq_len {
-                scores[[i, j]] = f32::NEG_INFINITY;
+        let q = self.cached_q.as_ref().unwrap();
+        let k = self.cached_k.as_ref().unwrap();
+        let v = self.cached_v.as_ref().unwrap();
+        let attn_weights = self.cached_attn_weights.as_ref().unwrap();
+
+        // Gradient through output projection: ∂L/∂W_o
+        let multi_head_output = {
+            let v_heads = self.split_heads(v);
+            let mut output_heads = Vec::with_capacity(self.num_heads);
+            for i in 0..self.num_heads {
+                let output = attn_weights[i].dot(&v_heads[i]);
+                output_heads.push(output);
             }
+            self.concat_heads(output_heads)
+        };
+
+        let grad_w_o = multi_head_output.t().dot(grads);
+        let grad_multi_head_output = grads.dot(&self.w_o.t());
+
+        // Split gradient back into heads
+        let grad_output_heads = self.split_heads(&grad_multi_head_output);
+
+        // Backward through each attention head
+        let q_heads = self.split_heads(q);
+        let k_heads = self.split_heads(k);
+        let v_heads = self.split_heads(v);
+
+        let mut grad_q_heads = Vec::with_capacity(self.num_heads);
+        let mut grad_k_heads = Vec::with_capacity(self.num_heads);
+        let mut grad_v_heads = Vec::with_capacity(self.num_heads);
+
+        for i in 0..self.num_heads {
+            let grad_head = &grad_output_heads[i];
+            let weights = &attn_weights[i];
+            let q_head = &q_heads[i];
+            let k_head = &k_heads[i];
+            let v_head = &v_heads[i];
+
+            // Gradient w.r.t. V
+            let grad_v_head = weights.t().dot(grad_head);
+
+            // Gradient w.r.t. attention weights
+            let grad_attn_weights = grad_head.dot(&v_head.t());
+
+            // Gradient through softmax
+            let grad_scores = SelfAttention::softmax_backward(weights, &grad_attn_weights);
+
+            // Scale factor for attention
+            let scale = (self.head_dim as f32).sqrt();
+
+            // Gradient w.r.t. Q and K
+            let grad_q_head = grad_scores.dot(k_head) / scale;
+            let grad_k_head = grad_scores.t().dot(q_head) / scale;
+
+            grad_q_heads.push(grad_q_head);
+            grad_k_heads.push(grad_k_head);
+            grad_v_heads.push(grad_v_head);
         }
+
+        // Concatenate head gradients
+        let grad_q = self.concat_heads(grad_q_heads);
+        let grad_k = self.concat_heads(grad_k_heads);
+        let grad_v = self.concat_heads(grad_v_heads);
-
-        let attn_weights = self.softmax(&scores); // also cached
-
-        // Step 1: grads = ∂L/∂attn_output
-        let grad_attn_weights = grads.dot(&v.t());
-        let grad_v = attn_weights.t().dot(grads);
-
-        // Step 2: softmax backward
-        let grad_scores = SelfAttention::softmax_backward(&attn_weights, &grad_attn_weights); // [seq_len, seq_len]
-
-        // Step 3: ∂L/∂Q and ∂L/∂K
-        let grad_q = grad_scores.dot(&k);
-        let grad_k = grad_scores.t().dot(&q);
-
-        // Step 4: ∂L/∂W_q/W_k/W_v
+        // Gradients w.r.t. weight matrices
         let grad_w_q = input.t().dot(&grad_q);
         let grad_w_k = input.t().dot(&grad_k);
         let grad_w_v = input.t().dot(&grad_v);
-
-        // Step 5: ∂L/∂input (gradient through attention computation)
-        let grad_input_attention =
-            grad_q.dot(&self.w_q.t()) +
-            grad_k.dot(&self.w_k.t()) +
-            grad_v.dot(&self.w_v.t());
-
-        // Step 6: Add gradient from residual connection
-        // Forward: residual = attention + input, so gradient flows directly through
+
+        // Gradient w.r.t. the input (through the attention computation)
+        let grad_input_attention = grad_q.dot(&self.w_q.t())
+            + grad_k.dot(&self.w_k.t())
+            + grad_v.dot(&self.w_v.t());
+
+        // Add gradient from residual connection
        let grad_input = grad_input_attention + grads;
-
-        // Step 7: update weights
+
+        // Update weights using Adam optimizer
         self.optimizer_w_q.step(&mut self.w_q, &grad_w_q, lr);
         self.optimizer_w_k.step(&mut self.w_k, &grad_w_k, lr);
         self.optimizer_w_v.step(&mut self.w_v, &grad_w_v, lr);
+        self.optimizer_w_o.step(&mut self.w_o, &grad_w_o, lr);
+
+        grad_input
+    }
 
-        grad_input
+    fn parameters(&self) -> usize {
+        self.w_k.len() + self.w_q.len() + self.w_v.len() + self.w_o.len()
     }
 }
\ No newline at end of file
diff --git a/src/transformer.rs b/src/transformer.rs
index aa1c613..cb3b8ec 100644
--- a/src/transformer.rs
+++ b/src/transformer.rs
@@ -11,9 +11,9 @@ pub struct TransformerBlock {
 }
 
 impl TransformerBlock {
-    pub fn new(embedding_dim: usize, hidden_dim: usize) -> Self {
+    pub fn new(embedding_dim: usize, hidden_dim: usize, num_heads: usize) -> Self {
         TransformerBlock {
-            attention: SelfAttention::new(embedding_dim),
+            attention: SelfAttention::new(embedding_dim, num_heads),
             feed_forward: FeedForward::new(embedding_dim, hidden_dim),
             norm1: LayerNorm::new(embedding_dim),
             norm2: LayerNorm::new(embedding_dim),
diff --git a/tests/llm_test.rs b/tests/llm_test.rs
index 3530099..7d27393 100644
--- a/tests/llm_test.rs
+++ b/tests/llm_test.rs
@@ -133,4 +133,42 @@ fn test_llm_integration() {
 
     llm.train(vec![ input_text ], 10, 0.01);
-}
\ No newline at end of file
+}
+
+
+
+#[test]
+fn test_llm_total_parameters() {
+    let vocab = Vocab::default();
+    let vocab_size = vocab.encode.len();
+
+    // Create an LLM with actual layers to get a meaningful parameter count
+    let num_heads = 8;
+    let embeddings = Box::new(Embeddings::new(vocab.clone()));
+    let transformer_block = Box::new(TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads));
+    let output_projection = Box::new(OutputProjection::new(EMBEDDING_DIM, vocab_size));
+
+    let llm = LLM::new(
+        vocab.clone(),
+        vec![embeddings, transformer_block, output_projection],
+    );
+
+    // The total parameters should be greater than 0 for a model with actual layers
+    let param_count = llm.total_parameters();
+    assert!(param_count > 0);
+
+    // Validate that this is equal to the expected total number of parameters (based on our
+    // source).
+    let expected_embeddings_parameters = vocab_size * EMBEDDING_DIM + MAX_SEQ_LEN * EMBEDDING_DIM;
+    let expected_transformer_block_parameters = (2 * EMBEDDING_DIM) + // LayerNorm 1
+        (4 * EMBEDDING_DIM * EMBEDDING_DIM) + // Multi-Head SelfAttention (Q, K, V, O projection matrices)
+        (2 * EMBEDDING_DIM) + // LayerNorm 2
+        (EMBEDDING_DIM * HIDDEN_DIM + HIDDEN_DIM + HIDDEN_DIM * EMBEDDING_DIM + EMBEDDING_DIM); // FeedForward
+    let expected_output_projection_parameters = EMBEDDING_DIM * vocab_size + vocab_size;
+    assert!(
+        param_count
+            == expected_embeddings_parameters
+                + expected_transformer_block_parameters
+                + expected_output_projection_parameters
+    );
+}
diff --git a/tests/self_attention_test.rs b/tests/self_attention_test.rs
index cd08341..5aed796 100644
--- a/tests/self_attention_test.rs
+++ b/tests/self_attention_test.rs
@@ -2,35 +2,109 @@ use llm::{Layer, EMBEDDING_DIM};
 use ndarray::Array2;
 use llm::self_attention::SelfAttention;
 
-#[test]
-fn test_self_attention_forward() {
-    // Create self-attention module
-    let mut self_attention = SelfAttention::new(EMBEDDING_DIM);
+// #[test]
+// fn test_self_attention_forward() {
+//     // Create self-attention module
+//     let mut self_attention = SelfAttention::new(EMBEDDING_DIM);
+
+//     // Create input tensor (batch_size=1, seq_len=3, embedding_dim=EMBEDDING_DIM)
+//     let input = Array2::ones((3, EMBEDDING_DIM));
+
+//     // Test forward pass
+//     let output = self_attention.forward(&input);
+
+//     // Check output shape - should be same as input
+//     assert_eq!(output.shape(), input.shape());
+// }
+
+// #[test]
+// fn test_self_attention_with_different_sequence_lengths() {
+//     // Create self-attention module
+//     let mut self_attention = SelfAttention::new(EMBEDDING_DIM);
+//     // Test with different sequence lengths
+//     for seq_len in 1..5 {
+//         // Create input tensor
+//         let input = Array2::ones((seq_len, EMBEDDING_DIM));
+
+//         // Test forward pass
+//         let output = self_attention.forward(&input);
+
+//         // Check output shape
+//         assert_eq!(output.shape(), [seq_len, EMBEDDING_DIM]);
+//     }
+// }
+
+
+
+#[test]
+fn test_multi_head_attention_forward() {
+    // Create multi-head self-attention module with 8 heads
+    let num_heads = 8;
+    let mut self_attention = SelfAttention::new(EMBEDDING_DIM, num_heads);
+
     // Create input tensor (batch_size=1, seq_len=3, embedding_dim=EMBEDDING_DIM)
     let input = Array2::ones((3, EMBEDDING_DIM));
-    
+
     // Test forward pass
     let output = self_attention.forward(&input);
-    
+
     // Check output shape - should be same as input
     assert_eq!(output.shape(), input.shape());
 }
 
 #[test]
-fn test_self_attention_with_different_sequence_lengths() {
-    // Create self-attention module
-    let mut self_attention = SelfAttention::new(EMBEDDING_DIM);
-    
+fn test_multi_head_attention_with_different_sequence_lengths() {
+    // Create multi-head self-attention module with 4 heads
+    let num_heads = 4;
+    let mut self_attention = SelfAttention::new(EMBEDDING_DIM, num_heads);
+
     // Test with different sequence lengths
     for seq_len in 1..5 {
         // Create input tensor
         let input = Array2::ones((seq_len, EMBEDDING_DIM));
-        
+
         // Test forward pass
         let output = self_attention.forward(&input);
-        
+
         // Check output shape
         assert_eq!(output.shape(), [seq_len, EMBEDDING_DIM]);
     }
-}
\ No newline at end of file
+}
+
+#[test]
+fn test_multi_head_attention_different_head_counts() {
+    // Test with different numbers of heads (must divide embedding_dim evenly)
+    let valid_head_counts = vec![1, 2, 4, 8, 16, 32, 64, 128];
+
+    for num_heads in valid_head_counts {
+        let mut self_attention = SelfAttention::new(EMBEDDING_DIM, num_heads);
+        let input = Array2::ones((3, EMBEDDING_DIM));
+        let output = self_attention.forward(&input);
+
+        // Verify output shape is correct
+        assert_eq!(output.shape(), [3, EMBEDDING_DIM]);
+
+        // Verify parameters are calculated correctly
+        // Q, K, V, O projection matrices: 4 * (EMBEDDING_DIM * EMBEDDING_DIM)
+        let expected_params = 4 * EMBEDDING_DIM * EMBEDDING_DIM;
+        assert_eq!(self_attention.parameters(), expected_params);
+    }
+}
+
+#[test]
+fn test_multi_head_attention_backward() {
+    let num_heads = 8;
+    let mut self_attention = SelfAttention::new(EMBEDDING_DIM, num_heads);
+
+    // Forward pass
+    let input = Array2::from_elem((4, EMBEDDING_DIM), 0.5);
+    let _output = self_attention.forward(&input);
+
+    // Backward pass with gradient
+    let grad_output = Array2::ones((4, EMBEDDING_DIM));
+    let grad_input = self_attention.backward(&grad_output, 0.01);
+
+    // Check gradient shape matches input shape
+    assert_eq!(grad_input.shape(), input.shape());
+}
diff --git a/tests/transformer_test.rs b/tests/transformer_test.rs
index 366ca59..44ae399 100644
--- a/tests/transformer_test.rs
+++ b/tests/transformer_test.rs
@@ -4,7 +4,9 @@ use llm::transformer::TransformerBlock;
 
 #[test]
 fn test_transformer_block() {
-    let mut transformer = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
+    let num_heads = 8;
+    let mut transformer = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads);
+
     // Create a simple input tensor
     let input = Array2::ones((1, EMBEDDING_DIM));