From 2bcd1a6d6ab351f09e0e6afe3cf78a50c4aac978 Mon Sep 17 00:00:00 2001 From: fderuiter <127706008+fderuiter@users.noreply.github.com> Date: Wed, 4 Mar 2026 13:38:59 +0000 Subject: [PATCH 1/2] feat: implement interactive activation functions plotter in GUI - Added `tanh`, `tanh_prime`, `gelu`, and `gelu_prime` to `math_explorer` calculus module. - Created `ActivationFunctionsTool` in `math_explorer_gui` using `egui_plot`. - Integrated tool into `AiTab` adhering to Strategy Pattern. - Checked off "Activation Functions" task in `todo_gui.md`. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- .../src/ai/deep_learning_theory/calculus.rs | 36 +++++ .../src/tabs/ai/activation_functions.rs | 134 ++++++++++++++++++ math_explorer_gui/src/tabs/ai/mod.rs | 2 + todo_gui.md | 4 +- 4 files changed, 174 insertions(+), 2 deletions(-) create mode 100644 math_explorer_gui/src/tabs/ai/activation_functions.rs diff --git a/math_explorer/src/ai/deep_learning_theory/calculus.rs b/math_explorer/src/ai/deep_learning_theory/calculus.rs index 7f29b7ed..5958bb7f 100644 --- a/math_explorer/src/ai/deep_learning_theory/calculus.rs +++ b/math_explorer/src/ai/deep_learning_theory/calculus.rs @@ -69,3 +69,39 @@ where let dz = prime_fn(z); grad_a.component_mul(&dz) // Element-wise multiplication (Hadamard product) } + +/// Tanh activation function. +pub fn tanh(z: &Vector) -> Vector { + z.map(|v| v.tanh()) +} + +/// Derivative of Tanh: 1 - tanh^2(z). +pub fn tanh_prime(z: &Vector) -> Vector { + let t = tanh(z); + t.map(|v| 1.0 - v * v) +} + +/// GELU (Gaussian Error Linear Unit) activation function. +/// Approximation: 0.5 * z * (1 + tanh(sqrt(2/pi) * (z + 0.044715 * z^3))) +pub fn gelu(z: &Vector) -> Vector { + let sqrt_2_over_pi = (2.0f64 / std::f64::consts::PI).sqrt(); + z.map(|v| { + 0.5 * v * (1.0 + (sqrt_2_over_pi * (v + 0.044715 * v.powi(3))).tanh()) + }) +} + +/// Derivative of GELU (Approximation). +pub fn gelu_prime(z: &Vector) -> Vector { + let sqrt_2_over_pi = (2.0f64 / std::f64::consts::PI).sqrt(); + z.map(|v| { + let x_cube = v.powi(3); + let inner = sqrt_2_over_pi * (v + 0.044715 * x_cube); + let tanh_inner = inner.tanh(); + let sech_sq = 1.0 - tanh_inner * tanh_inner; + + let term1 = 0.5 * (1.0 + tanh_inner); + let term2 = 0.5 * v * sech_sq * sqrt_2_over_pi * (1.0 + 3.0 * 0.044715 * v * v); + + term1 + term2 + }) +} diff --git a/math_explorer_gui/src/tabs/ai/activation_functions.rs b/math_explorer_gui/src/tabs/ai/activation_functions.rs new file mode 100644 index 00000000..af0108c7 --- /dev/null +++ b/math_explorer_gui/src/tabs/ai/activation_functions.rs @@ -0,0 +1,134 @@ +use crate::tabs::ai::AiTool; +use eframe::egui; +use egui_plot::{Legend, Line, Plot, PlotPoints}; +use math_explorer::ai::deep_learning_theory::calculus::{ + gelu, gelu_prime, relu, relu_prime, sigmoid, sigmoid_prime, tanh, tanh_prime, +}; +use math_explorer::ai::deep_learning_theory::linear_algebra::Vector; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ActivationFunction { + ReLU, + Sigmoid, + Tanh, + GELU, +} + +pub struct ActivationFunctionsTool { + selected_function: ActivationFunction, + x_min: f64, + x_max: f64, + points: usize, +} + +impl Default for ActivationFunctionsTool { + fn default() -> Self { + Self { + selected_function: ActivationFunction::ReLU, + x_min: -5.0, + x_max: 5.0, + points: 200, + } + } +} + +impl AiTool for ActivationFunctionsTool { + fn name(&self) -> &'static str { + "Activation Functions" + } + + fn show(&mut self, ctx: &egui::Context) { + egui::SidePanel::right("activation_functions_controls").show(ctx, |ui| { + ui.heading("Controls"); + ui.separator(); + + ui.label("Select Activation Function:"); + ui.radio_value(&mut self.selected_function, ActivationFunction::ReLU, "ReLU"); + ui.radio_value( + &mut self.selected_function, + ActivationFunction::Sigmoid, + "Sigmoid", + ); + ui.radio_value(&mut self.selected_function, ActivationFunction::Tanh, "Tanh"); + ui.radio_value(&mut self.selected_function, ActivationFunction::GELU, "GELU"); + + ui.separator(); + ui.label("Plot Range"); + ui.horizontal(|ui| { + ui.label("X Min:"); + ui.add(egui::DragValue::new(&mut self.x_min).speed(0.1)); + }); + ui.horizontal(|ui| { + ui.label("X Max:"); + ui.add(egui::DragValue::new(&mut self.x_max).speed(0.1)); + }); + + // Ensure min < max + if self.x_min >= self.x_max { + self.x_max = self.x_min + 1.0; + } + + ui.separator(); + ui.label("Math Info"); + match self.selected_function { + ActivationFunction::ReLU => { + ui.label("f(x) = max(0, x)"); + ui.label("f'(x) = 1 if x > 0 else 0"); + } + ActivationFunction::Sigmoid => { + ui.label("f(x) = 1 / (1 + e^(-x))"); + ui.label("f'(x) = f(x) * (1 - f(x))"); + } + ActivationFunction::Tanh => { + ui.label("f(x) = tanh(x)"); + ui.label("f'(x) = 1 - tanh^2(x)"); + } + ActivationFunction::GELU => { + ui.label("f(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))"); + } + } + }); + + egui::CentralPanel::default().show(ctx, |ui| { + // Generate x values + let step = (self.x_max - self.x_min) / (self.points as f64 - 1.0); + let mut x_vals = Vec::with_capacity(self.points); + for i in 0..self.points { + x_vals.push(self.x_min + (i as f64) * step); + } + + let x_vec = Vector::from_vec(x_vals.clone()); + + // Compute y values and derivatives + let (y_vec, dy_vec) = match self.selected_function { + ActivationFunction::ReLU => (relu(&x_vec), relu_prime(&x_vec)), + ActivationFunction::Sigmoid => (sigmoid(&x_vec), sigmoid_prime(&x_vec)), + ActivationFunction::Tanh => (tanh(&x_vec), tanh_prime(&x_vec)), + ActivationFunction::GELU => (gelu(&x_vec), gelu_prime(&x_vec)), + }; + + // Map to egui_plot points + let plot_points: PlotPoints = x_vals + .iter() + .zip(y_vec.iter()) + .map(|(x, y)| [*x, *y]) + .collect(); + + let deriv_points: PlotPoints = x_vals + .iter() + .zip(dy_vec.iter()) + .map(|(x, y)| [*x, *y]) + .collect(); + + let line = Line::new("f(x)", plot_points).width(2.0); + let deriv_line = Line::new("f'(x)", deriv_points).width(1.5); + + Plot::new("activation_function_plot") + .legend(Legend::default()) + .show(ui, |plot_ui| { + plot_ui.line(line); + plot_ui.line(deriv_line); + }); + }); + } +} diff --git a/math_explorer_gui/src/tabs/ai/mod.rs b/math_explorer_gui/src/tabs/ai/mod.rs index 8e68a3e9..e21fad37 100644 --- a/math_explorer_gui/src/tabs/ai/mod.rs +++ b/math_explorer_gui/src/tabs/ai/mod.rs @@ -1,6 +1,7 @@ use crate::tabs::ExplorerTab; use eframe::egui; +pub mod activation_functions; pub mod loss_landscape; pub mod training_monitor; @@ -24,6 +25,7 @@ impl Default for AiTab { tools: vec![ Box::new(loss_landscape::LossLandscapeTool::default()), Box::new(training_monitor::TrainingMonitorTool::default()), + Box::new(activation_functions::ActivationFunctionsTool::default()), ], selected_tool_index: 0, } diff --git a/todo_gui.md b/todo_gui.md index b59ed343..4a0fff5a 100644 --- a/todo_gui.md +++ b/todo_gui.md @@ -79,12 +79,12 @@ This document outlines the roadmap for integrating the various modules of the `m ## 3. Artificial Intelligence (AI) -### 3.1 Deep Learning Theory +### 3.1 Deep Learning Theory - **[Implemented]** * **Module:** `ai::deep_learning_theory` * **Features:** * [x] **Loss Landscape:** 3D surface plot of loss functions. * [x] **Training Monitor:** Real-time curves for training/validation loss and accuracy. - * [ ] **Activation Functions:** Interactive plotter for ReLU, Sigmoid, Tanh, GELU, etc. + * [x] **Activation Functions:** Interactive plotter for ReLU, Sigmoid, Tanh, GELU, etc. ### 3.2 Reinforcement Learning * **Module:** `ai::reinforcement_learning` From 0026568fda02a749e2c0a8e312260fe7903303c7 Mon Sep 17 00:00:00 2001 From: fderuiter <127706008+fderuiter@users.noreply.github.com> Date: Wed, 4 Mar 2026 14:03:19 +0000 Subject: [PATCH 2/2] fix: resolve formatting and clippy CI failures - Ran `cargo fmt` across all crates to ensure standard styling. - Fixed `clippy::upper_case_acronyms` in `ActivationFunction` enum (`GELU` to `Gelu`). - Fixed `clippy::needless_range_loop` violations in `neural_network_viz.rs`. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- .../src/ai/deep_learning_theory/calculus.rs | 4 +--- .../src/tabs/ai/activation_functions.rs | 24 ++++++++++++++----- .../tabs/neuroscience/neural_network_viz.rs | 20 +++++++++------- .../src/tabs/number_theory/prime_spiral.rs | 7 +++++- 4 files changed, 37 insertions(+), 18 deletions(-) diff --git a/math_explorer/src/ai/deep_learning_theory/calculus.rs b/math_explorer/src/ai/deep_learning_theory/calculus.rs index 5958bb7f..a1bae15a 100644 --- a/math_explorer/src/ai/deep_learning_theory/calculus.rs +++ b/math_explorer/src/ai/deep_learning_theory/calculus.rs @@ -85,9 +85,7 @@ pub fn tanh_prime(z: &Vector) -> Vector { /// Approximation: 0.5 * z * (1 + tanh(sqrt(2/pi) * (z + 0.044715 * z^3))) pub fn gelu(z: &Vector) -> Vector { let sqrt_2_over_pi = (2.0f64 / std::f64::consts::PI).sqrt(); - z.map(|v| { - 0.5 * v * (1.0 + (sqrt_2_over_pi * (v + 0.044715 * v.powi(3))).tanh()) - }) + z.map(|v| 0.5 * v * (1.0 + (sqrt_2_over_pi * (v + 0.044715 * v.powi(3))).tanh())) } /// Derivative of GELU (Approximation). diff --git a/math_explorer_gui/src/tabs/ai/activation_functions.rs b/math_explorer_gui/src/tabs/ai/activation_functions.rs index af0108c7..a5f924ca 100644 --- a/math_explorer_gui/src/tabs/ai/activation_functions.rs +++ b/math_explorer_gui/src/tabs/ai/activation_functions.rs @@ -11,7 +11,7 @@ enum ActivationFunction { ReLU, Sigmoid, Tanh, - GELU, + Gelu, } pub struct ActivationFunctionsTool { @@ -43,14 +43,26 @@ impl AiTool for ActivationFunctionsTool { ui.separator(); ui.label("Select Activation Function:"); - ui.radio_value(&mut self.selected_function, ActivationFunction::ReLU, "ReLU"); + ui.radio_value( + &mut self.selected_function, + ActivationFunction::ReLU, + "ReLU", + ); ui.radio_value( &mut self.selected_function, ActivationFunction::Sigmoid, "Sigmoid", ); - ui.radio_value(&mut self.selected_function, ActivationFunction::Tanh, "Tanh"); - ui.radio_value(&mut self.selected_function, ActivationFunction::GELU, "GELU"); + ui.radio_value( + &mut self.selected_function, + ActivationFunction::Tanh, + "Tanh", + ); + ui.radio_value( + &mut self.selected_function, + ActivationFunction::Gelu, + "GELU", + ); ui.separator(); ui.label("Plot Range"); @@ -83,7 +95,7 @@ impl AiTool for ActivationFunctionsTool { ui.label("f(x) = tanh(x)"); ui.label("f'(x) = 1 - tanh^2(x)"); } - ActivationFunction::GELU => { + ActivationFunction::Gelu => { ui.label("f(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))"); } } @@ -104,7 +116,7 @@ impl AiTool for ActivationFunctionsTool { ActivationFunction::ReLU => (relu(&x_vec), relu_prime(&x_vec)), ActivationFunction::Sigmoid => (sigmoid(&x_vec), sigmoid_prime(&x_vec)), ActivationFunction::Tanh => (tanh(&x_vec), tanh_prime(&x_vec)), - ActivationFunction::GELU => (gelu(&x_vec), gelu_prime(&x_vec)), + ActivationFunction::Gelu => (gelu(&x_vec), gelu_prime(&x_vec)), }; // Map to egui_plot points diff --git a/math_explorer_gui/src/tabs/neuroscience/neural_network_viz.rs b/math_explorer_gui/src/tabs/neuroscience/neural_network_viz.rs index fe76aa87..60c38832 100644 --- a/math_explorer_gui/src/tabs/neuroscience/neural_network_viz.rs +++ b/math_explorer_gui/src/tabs/neuroscience/neural_network_viz.rs @@ -39,8 +39,12 @@ impl Default for NeuralNetworkVizTool { let next = (i + 1) % num_neurons; let prev = (i + num_neurons - 1) % num_neurons; - weights[next][i] = 2.0; // Excitatory connection - weights[prev][i] = -1.0; // Inhibitory connection + if let Some(row) = weights.get_mut(next) { + row[i] = 2.0; // Excitatory connection + } + if let Some(row) = weights.get_mut(prev) { + row[i] = -1.0; // Inhibitory connection + } } // Inject some initial current into the first neuron to kickstart @@ -68,16 +72,16 @@ impl NeuralNetworkVizTool { // Compute synaptic currents. // Simple model: if a presynaptic neuron is firing (V > 0), inject current. - for i in 0..num_neurons { - for j in 0..num_neurons { - if self.neurons[j].v() > 0.0 { - next_inputs[i] += self.weights[i][j] * 50.0; // Synaptic strength scalar + for (i, input) in next_inputs.iter_mut().enumerate().take(num_neurons) { + for (j, neuron) in self.neurons.iter().enumerate().take(num_neurons) { + if neuron.v() > 0.0 { + *input += self.weights[i][j] * 50.0; // Synaptic strength scalar } } } - for i in 0..num_neurons { - let total_i_ext = self.external_input[i] + next_inputs[i]; + for (i, input) in next_inputs.iter().enumerate().take(num_neurons) { + let total_i_ext = self.external_input[i] + input; self.neurons[i].update(self.dt, total_i_ext); } } diff --git a/math_explorer_gui/src/tabs/number_theory/prime_spiral.rs b/math_explorer_gui/src/tabs/number_theory/prime_spiral.rs index 10fc4ec6..1f7077b8 100644 --- a/math_explorer_gui/src/tabs/number_theory/prime_spiral.rs +++ b/math_explorer_gui/src/tabs/number_theory/prime_spiral.rs @@ -62,7 +62,12 @@ impl PrimeSpiralWidget { is_prime_lookup[p as usize] = true; } - for (i, &is_prime) in is_prime_lookup.iter().enumerate().take(num_pixels + 1).skip(1) { + for (i, &is_prime) in is_prime_lookup + .iter() + .enumerate() + .take(num_pixels + 1) + .skip(1) + { // Fix unary negation on usize by casting to i32 first if (-(size as i32) / 2 <= x) && (x <= size as i32 / 2)