From f21f65dd6d1a4140749b5a983770c4518715f0e4 Mon Sep 17 00:00:00 2001
From: fderuiter <127706008+fderuiter@users.noreply.github.com>
Date: Tue, 10 Mar 2026 12:32:04 +0000
Subject: [PATCH] feat: Add Attention Maps visualization to AI tab

This commit implements the "Attention Maps" task from todo_gui.md.

It creates a new `AttentionMapsTool` struct that implements the `AiTool`
trait. The tool visualizes self-attention weights using `egui_plot::Plot`
and `egui_plot::Points` to create a heatmap. It utilizes the existing
`scaled_dot_product_attention` function from
`math_explorer::ai::transformer::attention` to compute the weights based
on interactively generated Query and Key matrices.

The tool is registered in `math_explorer_gui/src/tabs/ai/mod.rs` and the
todo list is updated.

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
---
Reviewer notes (text here is ignored by `git am`; not part of the commit message):
 - This patch was rebuilt from a copy whose line breaks and indentation had been
   collapsed. Whitespace of the context lines in the mod.rs and todo_gui.md hunks is
   a best-effort guess; `git apply --ignore-whitespace` may be needed.
 - The `Vec` element type had been stripped from the copy; `f32` is assumed (see the
   NOTE(review) in attention_maps.rs) and must be confirmed.
 - attention_maps.rs differs from the original commit (doc comments, no per-frame
   Vec clones, corrected comments), so the blob id on its `index` line is stale.

 .../src/tabs/ai/attention_maps.rs             | 151 ++++++++++++++++++
 math_explorer_gui/src/tabs/ai/mod.rs          |   2 +
 todo_gui.md                                   |   2 +-
 3 files changed, 154 insertions(+), 1 deletion(-)
 create mode 100644 math_explorer_gui/src/tabs/ai/attention_maps.rs

diff --git a/math_explorer_gui/src/tabs/ai/attention_maps.rs b/math_explorer_gui/src/tabs/ai/attention_maps.rs
new file mode 100644
index 00000000..6c309121
--- /dev/null
+++ b/math_explorer_gui/src/tabs/ai/attention_maps.rs
@@ -0,0 +1,151 @@
+use crate::tabs::ai::AiTool;
+use eframe::egui;
+use egui_plot::{MarkerShape, Plot, PlotPoint, PlotPoints, Points, Text};
+use math_explorer::ai::transformer::attention::scaled_dot_product_attention;
+use nalgebra::DMatrix;
+
+/// Interactive tool that renders self-attention weights as a heatmap.
+///
+/// `q_data`, `k_data` and `v_data` are flat buffers of `seq_len * d_k` values that are
+/// read row by row when the Q, K and V matrices are built in `show`.
+// NOTE(review): the element type of the three buffers was lost when this patch was
+// extracted; `f32` is assumed -- confirm against `scaled_dot_product_attention`.
+pub struct AttentionMapsTool {
+    /// Number of rows of Q, K and V (slider range 2..=10).
+    seq_len: usize,
+    /// Number of columns of Q, K and V (slider range 2..=8).
+    d_k: usize,
+    /// Flat storage for the query matrix.
+    q_data: Vec<f32>,
+    /// Flat storage for the key matrix.
+    k_data: Vec<f32>,
+    /// Flat storage for the value matrix; resized with the others but never randomized.
+    v_data: Vec<f32>,
+}
+
+impl Default for AttentionMapsTool {
+    /// Starts with `seq_len = 5`, `d_k = 4` and every Q, K and V entry set to 1.0.
+    fn default() -> Self {
+        let seq_len = 5;
+        let d_k = 4;
+        let n_elements = seq_len * d_k;
+        Self {
+            seq_len,
+            d_k,
+            q_data: vec![1.0; n_elements],
+            k_data: vec![1.0; n_elements],
+            v_data: vec![1.0; n_elements],
+        }
+    }
+}
+
+impl AiTool for AttentionMapsTool {
+    /// Returns the fixed name of this tool.
+    fn name(&self) -> &'static str {
+        "Attention Maps"
+    }
+
+    /// Draws the settings side panel and the heatmap in the central panel.
+    fn show(&mut self, ctx: &egui::Context) {
+        egui::SidePanel::left("attention_maps_controls").show(ctx, |ui| {
+            ui.heading("Settings");
+
+            let mut changed = false;
+            changed |= ui.add(egui::Slider::new(&mut self.seq_len, 2..=10).text("Sequence Length")).changed();
+            changed |= ui.add(egui::Slider::new(&mut self.d_k, 2..=8).text("Embedding Dim (d_k)")).changed();
+
+            // Keep every buffer at seq_len * d_k entries; new slots are filled with 1.0.
+            if changed {
+                let expected_size = self.seq_len * self.d_k;
+                self.q_data.resize(expected_size, 1.0);
+                self.k_data.resize(expected_size, 1.0);
+                self.v_data.resize(expected_size, 1.0);
+            }
+
+            ui.separator();
+            ui.heading("Inputs (Q, K)");
+            ui.label("Randomize or set values manually to see how attention weights change.");
+
+            if ui.button("Randomize Inputs").clicked() {
+                use rand::Rng;
+                let mut rng = rand::thread_rng();
+                // Q is redrawn first, then K, each entry from [-1, 1]; V is left as is.
+                for x in self.q_data.iter_mut().chain(self.k_data.iter_mut()) {
+                    *x = rng.r#gen_range(-1.0..=1.0);
+                }
+            }
+
+            ui.separator();
+            ui.label("Values (V) matrix is mostly irrelevant for visualizing attention weights.");
+        });
+
+        egui::CentralPanel::default().show(ctx, |ui| {
+            ui.heading("Self-Attention Heatmap");
+            ui.label("Visualizing the attention weights: softmax(Q * K^T / sqrt(d_k))");
+            ui.add_space(10.0);
+
+            // The flat buffers are treated as row-major, hence `from_row_iterator`
+            // (`from_vec` would fill column-major). Iterating by copy avoids cloning
+            // all three Vecs on every frame.
+            let q = DMatrix::from_row_iterator(self.seq_len, self.d_k, self.q_data.iter().copied());
+            let k = DMatrix::from_row_iterator(self.seq_len, self.d_k, self.k_data.iter().copied());
+            let v = DMatrix::from_row_iterator(self.seq_len, self.d_k, self.v_data.iter().copied());
+
+            let (_, attention_weights) = scaled_dot_product_attention(&q, &k, &v, None);
+
+            // Heatmap: one square marker per (row, col) cell, colored by its weight.
+            Plot::new("attention_heatmap")
+                .view_aspect(1.0)
+                .show_axes([false, false])
+                .show_grid(false)
+                .allow_drag(false)
+                .allow_zoom(false)
+                .allow_scroll(false)
+                .show(ui, |plot_ui| {
+                    for row in 0..self.seq_len {
+                        for col in 0..self.seq_len {
+                            // Clamp to [0, 1] so the color channels below stay in range.
+                            let weight = attention_weights[(row, col)].clamp(0.0, 1.0);
+
+                            // Color ramp from (255, 255, 50) at weight 0 to (255, 0, 0) at weight 1.
+                            let green = (255.0 * (1.0 - weight)) as u8;
+                            let blue = (50.0 * (1.0 - weight)) as u8;
+                            let color = egui::Color32::from_rgb(255, green, blue);
+
+                            // Invert Y axis so token 0 is at top
+                            let y = (self.seq_len - 1 - row) as f64;
+                            let x = col as f64;
+
+                            let points = Points::new(
+                                format!("cell_{}_{}", row, col),
+                                PlotPoints::new(vec![[x, y]]),
+                            )
+                            .color(color)
+                            .shape(MarkerShape::Square)
+                            .radius(30.0); // Size of the square
+
+                            plot_ui.points(points);
+
+                            // Draw the weight over the cell: white text above 0.5, black otherwise.
+                            let text_color = if weight > 0.5 {
+                                egui::Color32::WHITE
+                            } else {
+                                egui::Color32::BLACK
+                            };
+                            let label = Text::new(
+                                format!("text_{}_{}", row, col),
+                                PlotPoint::new(x, y),
+                                format!("{:.2}", weight),
+                            )
+                            .color(text_color);
+                            plot_ui.text(label);
+                        }
+                    }
+                });
+
+            ui.add_space(20.0);
+            ui.label("Tokens (Queries) on Y-axis attending to Tokens (Keys) on X-axis.");
+            ui.label("Top row = Token 0 Query, Left column = Token 0 Key");
+        });
+    }
+}
diff --git a/math_explorer_gui/src/tabs/ai/mod.rs b/math_explorer_gui/src/tabs/ai/mod.rs
index e21fad37..da8b1c53 100644
--- a/math_explorer_gui/src/tabs/ai/mod.rs
+++ b/math_explorer_gui/src/tabs/ai/mod.rs
@@ -2,6 +2,7 @@ use crate::tabs::ExplorerTab;
 use eframe::egui;
 
 pub mod activation_functions;
+pub mod attention_maps;
 pub mod loss_landscape;
 pub mod training_monitor;
 
@@ -26,6 +27,7 @@ impl Default for AiTab {
                 Box::new(loss_landscape::LossLandscapeTool::default()),
                 Box::new(training_monitor::TrainingMonitorTool::default()),
                 Box::new(activation_functions::ActivationFunctionsTool::default()),
+                Box::new(attention_maps::AttentionMapsTool::default()),
             ],
             selected_tool_index: 0,
         }
diff --git a/todo_gui.md b/todo_gui.md
index 4a0fff5a..915e36ed 100644
--- a/todo_gui.md
+++ b/todo_gui.md
@@ -96,7 +96,7 @@ This document outlines the roadmap for integrating the various modules of the `m
 ### 3.3 Transformers
 * **Module:** `ai::transformer`
 * **Features:**
-    * [ ] **Attention Maps:** Heatmap visualization of self-attention weights.
+    * [x] **Attention Maps:** Heatmap visualization of self-attention weights.
     * [ ] **Tokenization:** Text input field showing token breakdown and embeddings.
 
 ---