From 9c3905420122ff5e7505a469a8bcf981fe81d245 Mon Sep 17 00:00:00 2001
From: fderuiter <127706008+fderuiter@users.noreply.github.com>
Date: Thu, 12 Mar 2026 12:37:29 +0000
Subject: [PATCH] feat(gui): Add Attention Maps visualization tool

- Implement `AttentionMapsTool` for visualizing transformer self-attention mechanisms.
- Provide interactive inputs for Queries (Q), Keys (K), and Values (V) using `egui::Grid`.
- Visualize resulting Attention Weights and final Output matrices using custom egui heatmaps.
- Integrate tool into the existing `AiTab` via the `AiTool` strategy pattern.
- Update `todo_gui.md` to mark the feature as complete.

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
---
 .../src/tabs/ai/attention_maps.rs             | 169 ++++++++++++++++++
 math_explorer_gui/src/tabs/ai/mod.rs          |   2 +
 todo_gui.md                                   |   2 +-
 3 files changed, 172 insertions(+), 1 deletion(-)
 create mode 100644 math_explorer_gui/src/tabs/ai/attention_maps.rs

diff --git a/math_explorer_gui/src/tabs/ai/attention_maps.rs b/math_explorer_gui/src/tabs/ai/attention_maps.rs
new file mode 100644
index 00000000..e64210c1
--- /dev/null
+++ b/math_explorer_gui/src/tabs/ai/attention_maps.rs
@@ -0,0 +1,169 @@
+use super::AiTool;
+use eframe::egui;
+use nalgebra::DMatrix;
+
+pub struct AttentionMapsTool {
+    q_matrix: DMatrix<f64>,
+    k_matrix: DMatrix<f64>,
+    v_matrix: DMatrix<f64>,
+    output_matrix: DMatrix<f64>,
+    attention_weights: DMatrix<f64>,
+}
+
+impl Default for AttentionMapsTool {
+    fn default() -> Self {
+        let seq_len = 4;
+        let d_k = 4;
+
+        let mut tool = Self {
+            q_matrix: DMatrix::from_element(seq_len, d_k, 0.5),
+            k_matrix: DMatrix::from_element(seq_len, d_k, 0.5),
+            v_matrix: DMatrix::from_element(seq_len, d_k, 0.5),
+            output_matrix: DMatrix::zeros(seq_len, d_k),
+            attention_weights: DMatrix::zeros(seq_len, seq_len),
+        };
+        tool.recalculate();
+        tool
+    }
+}
+
+impl AttentionMapsTool {
+    fn recalculate(&mut self) {
+        let (output, weights) = math_explorer::ai::transformer::attention::scaled_dot_product_attention(
+            &self.q_matrix,
+            &self.k_matrix,
+            &self.v_matrix,
+            None,
+        );
+        self.output_matrix = output;
+        self.attention_weights = weights;
+    }
+
+    fn draw_matrix_input(ui: &mut egui::Ui, name: &str, matrix: &mut DMatrix<f64>) -> bool {
+        let mut changed = false;
+        ui.label(name);
+        egui::Grid::new(format!("{}_input_grid", name))
+            .num_columns(matrix.ncols())
+            .show(ui, |ui| {
+                for r in 0..matrix.nrows() {
+                    for c in 0..matrix.ncols() {
+                        let mut val = matrix[(r, c)];
+                        if ui
+                            .add(
+                                egui::DragValue::new(&mut val)
+                                    .speed(0.1)
+                                    .range(-10.0..=10.0),
+                            )
+                            .changed()
+                        {
+                            matrix[(r, c)] = val;
+                            changed = true;
+                        }
+                    }
+                    ui.end_row();
+                }
+            });
+        changed
+    }
+
+    fn draw_heatmap(ui: &mut egui::Ui, name: &str, matrix: &DMatrix<f64>) {
+        ui.label(name);
+
+        let min_val = matrix.iter().cloned().fold(f64::INFINITY, f64::min);
+        let max_val = matrix.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
+
+        egui::Grid::new(format!("{}_heatmap_grid", name))
+            .num_columns(matrix.ncols())
+            .spacing([2.0, 2.0])
+            .show(ui, |ui| {
+                for r in 0..matrix.nrows() {
+                    for c in 0..matrix.ncols() {
+                        let val = matrix[(r, c)];
+
+                        let normalized = if (max_val - min_val).abs() > 1e-6 {
+                            (val - min_val) / (max_val - min_val)
+                        } else {
+                            0.5
+                        };
+
+                        // Color ranges from dark blue (low) to yellow (high)
+                        let color = egui::Color32::from_rgb(
+                            (normalized * 255.0) as u8,
+                            (normalized * 200.0) as u8,
+                            ((1.0 - normalized) * 200.0 + 55.0) as u8,
+                        );
+
+                        let (rect, _response) = ui.allocate_exact_size(
+                            egui::vec2(40.0, 40.0),
+                            egui::Sense::hover(),
+                        );
+
+                        ui.painter().rect_filled(rect, 2.0, color);
+
+                        let text_color = if normalized > 0.5 {
+                            egui::Color32::BLACK
+                        } else {
+                            egui::Color32::WHITE
+                        };
+
+                        ui.painter().text(
+                            rect.center(),
+                            egui::Align2::CENTER_CENTER,
+                            format!("{:.2}", val),
+                            egui::FontId::proportional(12.0),
+                            text_color,
+                        );
+                    }
+                    ui.end_row();
+                }
+            });
+    }
+}
+
+impl AiTool for AttentionMapsTool {
+    fn name(&self) -> &'static str {
+        "Attention Maps"
+    }
+
+    fn show(&mut self, ctx: &egui::Context) {
+        egui::CentralPanel::default().show(ctx, |ui| {
+            ui.heading("Self-Attention Mechanism");
+            ui.label("Explore how Queries (Q), Keys (K), and Values (V) interact.");
+
+            egui::ScrollArea::vertical().show(ui, |ui| {
+                ui.horizontal(|ui| {
+                    ui.vertical(|ui| {
+                        let mut changed = false;
+                        ui.group(|ui| {
+                            changed |= Self::draw_matrix_input(ui, "Queries (Q)", &mut self.q_matrix);
+                        });
+                        ui.group(|ui| {
+                            changed |= Self::draw_matrix_input(ui, "Keys (K)", &mut self.k_matrix);
+                        });
+                        ui.group(|ui| {
+                            changed |= Self::draw_matrix_input(ui, "Values (V)", &mut self.v_matrix);
+                        });
+
+                        if changed {
+                            self.recalculate();
+                        }
+                    });
+
+                    ui.separator();
+
+                    ui.vertical(|ui| {
+                        ui.group(|ui| {
+                            Self::draw_heatmap(ui, "Attention Weights (softmax(Q * K^T / sqrt(d_k)))", &self.attention_weights);
+                        });
+
+                        ui.add_space(20.0);
+
+                        ui.group(|ui| {
+                            Self::draw_heatmap(ui, "Output (Weights * V)", &self.output_matrix);
+                        });
+                    });
+                });
+            });
+        });
+    }
+}
diff --git a/math_explorer_gui/src/tabs/ai/mod.rs b/math_explorer_gui/src/tabs/ai/mod.rs
index e21fad37..da8b1c53 100644
--- a/math_explorer_gui/src/tabs/ai/mod.rs
+++ b/math_explorer_gui/src/tabs/ai/mod.rs
@@ -2,6 +2,7 @@ use crate::tabs::ExplorerTab;
 use eframe::egui;
 
 pub mod activation_functions;
+pub mod attention_maps;
 pub mod loss_landscape;
 pub mod training_monitor;
 
@@ -26,6 +27,7 @@ impl Default for AiTab {
                 Box::new(loss_landscape::LossLandscapeTool::default()),
                 Box::new(training_monitor::TrainingMonitorTool::default()),
                 Box::new(activation_functions::ActivationFunctionsTool::default()),
+                Box::new(attention_maps::AttentionMapsTool::default()),
             ],
             selected_tool_index: 0,
         }
diff --git a/todo_gui.md b/todo_gui.md
index 4a0fff5a..915e36ed 100644
--- a/todo_gui.md
+++ b/todo_gui.md
@@ -96,7 +96,7 @@ This document outlines the roadmap for integrating the various modules of the `m
 ### 3.3 Transformers
 * **Module:** `ai::transformer`
 * **Features:**
-    * [ ] **Attention Maps:** Heatmap visualization of self-attention weights.
+    * [x] **Attention Maps:** Heatmap visualization of self-attention weights.
     * [ ] **Tokenization:** Text input field showing token breakdown and embeddings.
 
 ---