5 changes: 3 additions & 2 deletions src/llm.rs
@@ -23,7 +23,8 @@ pub struct LLM {

impl Default for LLM {
fn default() -> Self {
let transformer_block = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
let num_heads = 8; // Default to 8 attention heads
Owner: I feel this should be with the rest of the constants, e.g. MAX_SEQ_LEN and such.
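A minimal sketch of the suggested change, assuming the shared constants module that already holds MAX_SEQ_LEN, EMBEDDING_DIM, and HIDDEN_DIM (the NUM_HEADS name and its location are assumptions, not part of this PR):

```rust
// Hypothetical: placed alongside MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM in the shared constants module.
// EMBEDDING_DIM must remain divisible by NUM_HEADS (128 / 8 = 16 dims per head).
pub const NUM_HEADS: usize = 8;
```

Call sites like the one below could then pass NUM_HEADS instead of a local num_heads binding.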

let transformer_block = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads);
let output_projection = OutputProjection::new(EMBEDDING_DIM, Vocab::default_words().len());
Self {
vocab: Vocab::default(),
@@ -128,7 +129,7 @@ impl LLM {
pub fn train(&mut self, data: Vec<&str>, epochs: usize, lr: f32) {
let tokenized_data = data
.iter()
.map(|input| (self.tokenize(input)))
.map(|input| self.tokenize(input))
.collect::<Vec<Vec<usize>>>();

for epoch in 0..epochs {
13 changes: 10 additions & 3 deletions src/main.rs
@@ -173,9 +173,11 @@ fn main() {
let vocab_words_refs: Vec<&str> = vocab_words.iter().map(|s| s.as_str()).collect();
let vocab = Vocab::new(vocab_words_refs);

let transformer_block_1 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
let transformer_block_2 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
let transformer_block_3 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
// Using 8 attention heads (EMBEDDING_DIM=128 / 8 = 16 dim per head)
let num_heads = 8;
Owner: We should share this with a universal const.
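Assuming the hypothetical shared NUM_HEADS constant sketched in the earlier comment, the three call sites below could drop the local binding:

```rust
// Hypothetical usage of a shared NUM_HEADS constant (not part of this PR as written).
let transformer_block_1 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS);
let transformer_block_2 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS);
let transformer_block_3 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, NUM_HEADS);
```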

let transformer_block_1 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads);
let transformer_block_2 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads);
let transformer_block_3 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM, num_heads);
let output_projection = OutputProjection::new(EMBEDDING_DIM, vocab.words.len());
let embeddings = Embeddings::new(vocab.clone());
let mut llm = LLM::new(vocab, vec![
@@ -188,6 +190,11 @@

println!("\n=== MODEL INFORMATION ===");
println!("Network architecture: {}", llm.network_description());
println!(
"Model configuration -> max_seq_len: {}, embedding_dim: {}, hidden_dim: {}, num_heads: {}",
MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM, num_heads
);


println!("\n=== BEFORE TRAINING ===");
println!("Input: {}", string);
235 changes: 186 additions & 49 deletions src/self_attention.rs
@@ -7,41 +7,64 @@ use std::f32;

pub struct SelfAttention {
pub embedding_dim: usize,
pub num_heads: usize,
pub head_dim: usize,
w_q: Array2<f32>, // Weight matrices for Q, K, V
w_k: Array2<f32>,
w_v: Array2<f32>,
w_o: Array2<f32>, // Output projection matrix

cached_input: Option<Array2<f32>>,
cached_q: Option<Array2<f32>>,
cached_k: Option<Array2<f32>>,
cached_v: Option<Array2<f32>>,
cached_attn_weights: Option<Vec<Array2<f32>>>,

optimizer_w_q: Adam,
optimizer_w_k: Adam,
optimizer_w_v: Adam,
optimizer_w_o: Adam,
}

impl Default for SelfAttention {
fn default() -> Self {
SelfAttention::new(EMBEDDING_DIM)
SelfAttention::new(EMBEDDING_DIM, 8) // 8 attention heads by default
Owner: Same here!

}
}


impl SelfAttention {
/// Initializes a Transformer with random Q, K, V weights
pub fn new(embedding_dim: usize) -> Self {
/// Initializes a Multi-Head Attention with random Q, K, V, O weights
/// num_heads: Number of attention heads (embedding_dim must be divisible by num_heads)
pub fn new(embedding_dim: usize, num_heads: usize) -> Self {
assert!(
embedding_dim % num_heads == 0,
"embedding_dim must be divisible by num_heads"
);

let head_dim = embedding_dim / num_heads;
let mut rng = rand::rng();
// Xavier/He initialization: std = sqrt(2 / fan_in)
let std = (2.0 / embedding_dim as f32).sqrt();
let normal = Normal::new(0.0, std).unwrap();

SelfAttention {
embedding_dim,
num_heads,
head_dim,
w_q: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)),
w_k: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)),
w_v: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)),
w_o: Array2::from_shape_fn((embedding_dim, embedding_dim), |_| normal.sample(&mut rng)),
cached_input: None,
cached_q: None,
cached_k: None,
Owner: And caching!! Very cool.

cached_v: None,
cached_attn_weights: None,
optimizer_w_q: Adam::new((embedding_dim, embedding_dim)),
optimizer_w_k: Adam::new((embedding_dim, embedding_dim)),
optimizer_w_v: Adam::new((embedding_dim, embedding_dim)),
optimizer_w_o: Adam::new((embedding_dim, embedding_dim)),
}
}

@@ -50,10 +73,43 @@ impl SelfAttention {
let k = input.dot(&self.w_k); // K = X * W_K
let v = input.dot(&self.w_v); // V = X * W_V
(q, k, v)

}
/// Split the input tensor into multiple heads
/// Input shape: [seq_len, embedding_dim]
/// Output shape: [num_heads, seq_len, head_dim]
fn split_heads(&self, x: &Array2<f32>) -> Vec<Array2<f32>> {
let mut heads = Vec::with_capacity(self.num_heads);

for h in 0..self.num_heads {
let start = h * self.head_dim;
let end = start + self.head_dim;
let head = x.slice(ndarray::s![.., start..end]).to_owned();
heads.push(head);
}

heads
}

/// Concatenate multiple heads back together
/// Input: Vec of [seq_len, head_dim] arrays
/// Output shape: [seq_len, embedding_dim]
fn concat_heads(&self, heads: Vec<Array2<f32>>) -> Array2<f32> {
let seq_len = heads[0].shape()[0];
let mut result = Array2::zeros((seq_len, self.embedding_dim));

for (h, head) in heads.iter().enumerate() {
let start = h * self.head_dim;
let end = start + self.head_dim;
result.slice_mut(ndarray::s![.., start..end]).assign(head);
}

result
}

fn attention(&self, q: &Array2<f32>, k: &Array2<f32>, v: &Array2<f32>) -> Array2<f32> {
let dk = (self.embedding_dim as f32).sqrt();
/// Single-head attention computation
fn attention_head(&self, q: &Array2<f32>, k: &Array2<f32>, v: &Array2<f32>) -> (Array2<f32>, Array2<f32>) {
let dk = (self.head_dim as f32).sqrt();

let k_t = k.t();
let mut scores = q.dot(&k_t) / dk;
@@ -67,7 +123,36 @@
}

let weights = self.softmax(&scores);
weights.dot(v)
let output = weights.dot(v);
(output, weights)
}

/// Multi-head attention: applies attention independently for each head
fn multi_head_attention(
&self,
q: &Array2<f32>,
k: &Array2<f32>,
v: &Array2<f32>,
) -> (Array2<f32>, Vec<Array2<f32>>) {
// Split into heads
let q_heads = self.split_heads(q);
let k_heads = self.split_heads(k);
let v_heads = self.split_heads(v);

// Apply attention for each head
let mut output_heads = Vec::with_capacity(self.num_heads);
let mut attn_weights = Vec::with_capacity(self.num_heads);

for i in 0..self.num_heads {
let (head_output, head_weights) = self.attention_head(&q_heads[i], &k_heads[i], &v_heads[i]);
output_heads.push(head_output);
attn_weights.push(head_weights);
}

// Concatenate heads
let concat_output = self.concat_heads(output_heads);

(concat_output, attn_weights)
}

fn softmax(&self, scores: &Array2<f32>) -> Array2<f32> {
@@ -123,67 +208,119 @@

impl Layer for SelfAttention {
fn layer_type(&self) -> &str {
"SelfAttention"
"MUltiHeadSelfAttention"
}

fn forward(&mut self, input: &Array2<f32>) -> Array2<f32> {
// Compute Q, K, V projections
let (q, k, v) = self.compute_qkv(input);

// Cache for backward pass
self.cached_input = Some(input.clone());
let qkv = self.compute_qkv(input);
let attention = self.attention(&qkv.0, &qkv.1, &qkv.2);
attention + input // residual connection (no LayerNorm here)
self.cached_q = Some(q.clone());
self.cached_k = Some(k.clone());
self.cached_v = Some(v.clone());

// Apply multi-head attention
let (multi_head_output, attn_weights) = self.multi_head_attention(&q, &k, &v);
self.cached_attn_weights = Some(attn_weights);

// Apply output projection
let projected = multi_head_output.dot(&self.w_o);

// Add residual connection
projected + input
}

fn backward(&mut self, grads: &Array2<f32>, lr: f32) -> Array2<f32> {
let input = self.cached_input.as_ref().unwrap();
let q = input.dot(&self.w_q);
let k = input.dot(&self.w_k);
let v = input.dot(&self.w_v);
let dk = self.w_q.shape()[1] as f32;
let scale = dk.sqrt();

let mut scores = q.dot(&k.t()) / scale;

// Apply causal masking - prevent attention to future tokens
let seq_len = scores.shape()[0];
for i in 0..seq_len {
for j in (i + 1)..seq_len {
scores[[i, j]] = f32::NEG_INFINITY;
let q = self.cached_q.as_ref().unwrap();
let k = self.cached_k.as_ref().unwrap();
let v = self.cached_v.as_ref().unwrap();
let attn_weights = self.cached_attn_weights.as_ref().unwrap();

// Gradient through output projection: ∂L/∂W_o
let multi_head_output = {
let v_heads = self.split_heads(v);
let mut output_heads = Vec::with_capacity(self.num_heads);
for i in 0..self.num_heads {
let output = attn_weights[i].dot(&v_heads[i]);
output_heads.push(output);
}
self.concat_heads(output_heads)
};

let grad_w_o = multi_head_output.t().dot(grads);
let grad_multi_head_output = grads.dot(&self.w_o.t());

// Split gradient back into heads
let grad_output_heads = self.split_heads(&grad_multi_head_output);

// Backward through each attention head
let q_heads = self.split_heads(q);
let k_heads = self.split_heads(k);
let v_heads = self.split_heads(v);

let mut grad_q_heads = Vec::with_capacity(self.num_heads);
let mut grad_k_heads = Vec::with_capacity(self.num_heads);
let mut grad_v_heads = Vec::with_capacity(self.num_heads);

for i in 0..self.num_heads {
let grad_head = &grad_output_heads[i];
let weights = &attn_weights[i];
let q_head = &q_heads[i];
let k_head = &k_heads[i];
let v_head = &v_heads[i];

// Gradient w.r.t. V
let grad_v_head = weights.t().dot(grad_head);

// Gradient w.r.t. attention weights
let grad_attn_weights = grad_head.dot(&v_head.t());

// Gradient through softmax
let grad_scores = SelfAttention::softmax_backward(weights, &grad_attn_weights);

// Scale factor for attention
let scale = (self.head_dim as f32).sqrt();

// Gradient w.r.t. Q and K
let grad_q_head = grad_scores.dot(k_head) / scale;
let grad_k_head = grad_scores.t().dot(q_head) / scale;

grad_q_heads.push(grad_q_head);
grad_k_heads.push(grad_k_head);
grad_v_heads.push(grad_v_head);
}

// Concatenate head gradients
let grad_q = self.concat_heads(grad_q_heads);
let grad_k = self.concat_heads(grad_k_heads);
let grad_v = self.concat_heads(grad_v_heads);

let attn_weights = self.softmax(&scores); // also cached

// Step 1: grads = ∂L/∂attn_output
let grad_attn_weights = grads.dot(&v.t());
let grad_v = attn_weights.t().dot(grads);

// Step 2: softmax backward
let grad_scores = SelfAttention::softmax_backward(&attn_weights, &grad_attn_weights); // [seq_len, seq_len]

// Step 3: ∂L/∂Q and ∂L/∂K
let grad_q = grad_scores.dot(&k);
let grad_k = grad_scores.t().dot(&q);

// Step 4: ∂L/∂W_q/W_k/W_v
// Gradients w.r.t. weight matrices
let grad_w_q = input.t().dot(&grad_q);
let grad_w_k = input.t().dot(&grad_k);
let grad_w_v = input.t().dot(&grad_v);

// Step 5: ∂L/∂input (gradient through attention computation)
let grad_input_attention =
grad_q.dot(&self.w_q.t()) +
grad_k.dot(&self.w_k.t()) +
grad_v.dot(&self.w_v.t());

// Step 6: Add gradient from residual connection
// Forward: residual = attention + input, so gradient flows directly through

// Gradients w.r.t. weight matrices
let grad_input_attention = grad_q.dot(&self.w_q.t())
+ grad_k.dot(&self.w_k.t())
+ grad_v.dot(&self.w_v.t());

// Add gradient from residual connection
let grad_input = grad_input_attention + grads;
// Step 7: update weights

// Update weights using Adam optimizer
self.optimizer_w_q.step(&mut self.w_q, &grad_w_q, lr);
self.optimizer_w_k.step(&mut self.w_k, &grad_w_k, lr);
self.optimizer_w_v.step(&mut self.w_v, &grad_w_v, lr);
self.optimizer_w_o.step(&mut self.w_o, &grad_w_o, lr);

grad_input
}

grad_input
fn parameters(&self) -> usize {
self.w_k.len() + self.w_q.len() + self.w_v.len() + self.w_o.len()
}
}
4 changes: 2 additions & 2 deletions src/transformer.rs
@@ -11,9 +11,9 @@ pub struct TransformerBlock {
}

impl TransformerBlock {
pub fn new(embedding_dim: usize, hidden_dim: usize) -> Self {
pub fn new(embedding_dim: usize, hidden_dim: usize, num_heads: usize) -> Self {
TransformerBlock {
attention: SelfAttention::new(embedding_dim),
attention: SelfAttention::new(embedding_dim, num_heads),
feed_forward: FeedForward::new(embedding_dim, hidden_dim),
norm1: LayerNorm::new(embedding_dim),
norm2: LayerNorm::new(embedding_dim),