From 877b3b48d04732c81e2fe9d24bcda8d1cbded4b5 Mon Sep 17 00:00:00 2001 From: fderuiter <127706008+fderuiter@users.noreply.github.com> Date: Tue, 3 Mar 2026 13:37:38 +0000 Subject: [PATCH 1/2] feat(gui): implement Activation Functions interactive plotter - Adds `tanh`, `tanh_prime`, `gelu`, and `gelu_prime` functions to `math_explorer/src/ai/deep_learning_theory/calculus.rs`. - Creates `ActivationFunctionsTool` in `math_explorer_gui/src/tabs/ai/activation_functions.rs` using `egui_plot` to visualize the functions and their derivatives. - Integrates the new tool into `AiTab` (`math_explorer_gui/src/tabs/ai/mod.rs`). - Updates `todo_gui.md` marking the task as complete. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- .../src/ai/deep_learning_theory/calculus.rs | 30 ++++ .../src/tabs/ai/activation_functions.rs | 134 ++++++++++++++++++ math_explorer_gui/src/tabs/ai/mod.rs | 2 + .../src/tabs/number_theory/prime_spiral.rs | 7 +- todo_gui.md | 2 +- 5 files changed, 173 insertions(+), 2 deletions(-) create mode 100644 math_explorer_gui/src/tabs/ai/activation_functions.rs diff --git a/math_explorer/src/ai/deep_learning_theory/calculus.rs b/math_explorer/src/ai/deep_learning_theory/calculus.rs index 7f29b7ed..a13f7d0c 100644 --- a/math_explorer/src/ai/deep_learning_theory/calculus.rs +++ b/math_explorer/src/ai/deep_learning_theory/calculus.rs @@ -69,3 +69,33 @@ where let dz = prime_fn(z); grad_a.component_mul(&dz) // Element-wise multiplication (Hadamard product) } + +/// Tanh activation function: f(z) = tanh(z). +pub fn tanh(z: &Vector) -> Vector { + z.map(|v| v.tanh()) +} + +/// Derivative of Tanh: f'(z) = 1 - tanh^2(z). +pub fn tanh_prime(z: &Vector) -> Vector { + let t = tanh(z); + t.map(|v| 1.0 - v * v) +} + +/// GELU (Gaussian Error Linear Unit) activation function: f(z) = 0.5 * z * (1 + tanh(sqrt(2/pi) * (z + 0.044715 * z^3))). +/// Using the approximate formulation. +pub fn gelu(z: &Vector) -> Vector { + let sqrt_2_over_pi = (2.0f64 / std::f64::consts::PI).sqrt(); + z.map(|v| 0.5 * v * (1.0 + (sqrt_2_over_pi * (v + 0.044715 * v.powi(3))).tanh())) +} + +/// Derivative of GELU (approximate). +pub fn gelu_prime(z: &Vector) -> Vector { + let sqrt_2_over_pi = (2.0f64 / std::f64::consts::PI).sqrt(); + z.map(|v| { + let x = sqrt_2_over_pi * (v + 0.044715 * v.powi(3)); + let tanh_x = x.tanh(); + let sech2_x = 1.0 - tanh_x * tanh_x; + 0.5 * (1.0 + tanh_x) + + 0.5 * v * sech2_x * sqrt_2_over_pi * (1.0 + 3.0 * 0.044715 * v.powi(2)) + }) +} diff --git a/math_explorer_gui/src/tabs/ai/activation_functions.rs b/math_explorer_gui/src/tabs/ai/activation_functions.rs new file mode 100644 index 00000000..12fc2e96 --- /dev/null +++ b/math_explorer_gui/src/tabs/ai/activation_functions.rs @@ -0,0 +1,134 @@ +use super::AiTool; +use eframe::egui; +use egui_plot::{Line, LineStyle, Plot, PlotPoints}; +use math_explorer::ai::deep_learning_theory::calculus::{ + gelu, gelu_prime, relu, relu_prime, sigmoid, sigmoid_prime, tanh, tanh_prime, +}; +use nalgebra::DVector; + +#[derive(PartialEq, Clone, Copy)] +enum ActivationFunction { + ReLU, + Sigmoid, + Tanh, + GELU, +} + +impl ActivationFunction { + fn name(&self) -> &'static str { + match self { + ActivationFunction::ReLU => "ReLU", + ActivationFunction::Sigmoid => "Sigmoid", + ActivationFunction::Tanh => "Tanh", + ActivationFunction::GELU => "GELU", + } + } +} + +pub struct ActivationFunctionsTool { + selected_function: ActivationFunction, + show_derivative: bool, + x_min: f64, + x_max: f64, + points_count: usize, +} + +impl Default for ActivationFunctionsTool { + fn default() -> Self { + Self { + selected_function: ActivationFunction::ReLU, + show_derivative: false, + x_min: -5.0, + x_max: 5.0, + points_count: 500, + } + } +} + +impl AiTool for ActivationFunctionsTool { + fn name(&self) -> &'static str { + "Activation Functions" + } + + fn show(&mut self, ctx: &egui::Context) { + egui::SidePanel::left("activation_controls").show(ctx, |ui| { + ui.heading("Controls"); + ui.separator(); + + ui.label("Select Function:"); + ui.radio_value( + &mut self.selected_function, + ActivationFunction::ReLU, + "ReLU", + ); + ui.radio_value( + &mut self.selected_function, + ActivationFunction::Sigmoid, + "Sigmoid", + ); + ui.radio_value( + &mut self.selected_function, + ActivationFunction::Tanh, + "Tanh", + ); + ui.radio_value( + &mut self.selected_function, + ActivationFunction::GELU, + "GELU", + ); + + ui.separator(); + ui.checkbox(&mut self.show_derivative, "Show Derivative"); + + ui.separator(); + ui.label("Domain:"); + ui.add(egui::Slider::new(&mut self.x_min, -10.0..=-1.0).text("Min X")); + ui.add(egui::Slider::new(&mut self.x_max, 1.0..=10.0).text("Max X")); + }); + + egui::CentralPanel::default().show(ctx, |ui| { + ui.heading("Activation Function Plot"); + + // Generate data points + let step = (self.x_max - self.x_min) / (self.points_count as f64 - 1.0); + let x_vals: Vec = (0..self.points_count) + .map(|i| self.x_min + i as f64 * step) + .collect(); + let x_vec = DVector::from_vec(x_vals.clone()); + + let (y_vals, y_prime_vals) = match self.selected_function { + ActivationFunction::ReLU => (relu(&x_vec), relu_prime(&x_vec)), + ActivationFunction::Sigmoid => (sigmoid(&x_vec), sigmoid_prime(&x_vec)), + ActivationFunction::Tanh => (tanh(&x_vec), tanh_prime(&x_vec)), + ActivationFunction::GELU => (gelu(&x_vec), gelu_prime(&x_vec)), + }; + + let points: PlotPoints = x_vals + .iter() + .zip(y_vals.iter()) + .map(|(&x, &y)| [x, y]) + .collect(); + let line = Line::new(self.selected_function.name(), points); + + let plot = Plot::new("activation_plot") + .legend(egui_plot::Legend::default()) + .view_aspect(2.0); + + plot.show(ui, |plot_ui| { + plot_ui.line(line); + + if self.show_derivative { + let prime_points: PlotPoints = x_vals + .iter() + .zip(y_prime_vals.iter()) + .map(|(&x, &y)| [x, y]) + .collect(); + let prime_line = + Line::new(format!("{}'", self.selected_function.name()), prime_points) + .style(LineStyle::Dashed { length: 5.0 }); + plot_ui.line(prime_line); + } + }); + }); + } +} diff --git a/math_explorer_gui/src/tabs/ai/mod.rs b/math_explorer_gui/src/tabs/ai/mod.rs index 8e68a3e9..e21fad37 100644 --- a/math_explorer_gui/src/tabs/ai/mod.rs +++ b/math_explorer_gui/src/tabs/ai/mod.rs @@ -1,6 +1,7 @@ use crate::tabs::ExplorerTab; use eframe::egui; +pub mod activation_functions; pub mod loss_landscape; pub mod training_monitor; @@ -24,6 +25,7 @@ impl Default for AiTab { tools: vec![ Box::new(loss_landscape::LossLandscapeTool::default()), Box::new(training_monitor::TrainingMonitorTool::default()), + Box::new(activation_functions::ActivationFunctionsTool::default()), ], selected_tool_index: 0, } diff --git a/math_explorer_gui/src/tabs/number_theory/prime_spiral.rs b/math_explorer_gui/src/tabs/number_theory/prime_spiral.rs index 10fc4ec6..1f7077b8 100644 --- a/math_explorer_gui/src/tabs/number_theory/prime_spiral.rs +++ b/math_explorer_gui/src/tabs/number_theory/prime_spiral.rs @@ -62,7 +62,12 @@ impl PrimeSpiralWidget { is_prime_lookup[p as usize] = true; } - for (i, &is_prime) in is_prime_lookup.iter().enumerate().take(num_pixels + 1).skip(1) { + for (i, &is_prime) in is_prime_lookup + .iter() + .enumerate() + .take(num_pixels + 1) + .skip(1) + { // Fix unary negation on usize by casting to i32 first if (-(size as i32) / 2 <= x) && (x <= size as i32 / 2) diff --git a/todo_gui.md b/todo_gui.md index b59ed343..fb4c9a76 100644 --- a/todo_gui.md +++ b/todo_gui.md @@ -84,7 +84,7 @@ This document outlines the roadmap for integrating the various modules of the `m * **Features:** * [x] **Loss Landscape:** 3D surface plot of loss functions. * [x] **Training Monitor:** Real-time curves for training/validation loss and accuracy. - * [ ] **Activation Functions:** Interactive plotter for ReLU, Sigmoid, Tanh, GELU, etc. + * [x] **Activation Functions:** Interactive plotter for ReLU, Sigmoid, Tanh, GELU, etc. ### 3.2 Reinforcement Learning * **Module:** `ai::reinforcement_learning` From 77f2388a885de2353ef8348bb5fc19449590e11b Mon Sep 17 00:00:00 2001 From: fderuiter <127706008+fderuiter@users.noreply.github.com> Date: Tue, 3 Mar 2026 14:47:31 +0000 Subject: [PATCH 2/2] feat(gui): implement Activation Functions interactive plotter Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>