30 changes: 30 additions & 0 deletions math_explorer/src/ai/deep_learning_theory/calculus.rs

@@ -69,3 +69,33 @@ where
let dz = prime_fn(z);
grad_a.component_mul(&dz) // Element-wise multiplication (Hadamard product)
}

/// Tanh activation function: f(z) = tanh(z).
pub fn tanh(z: &Vector) -> Vector {
z.map(|v| v.tanh())
}

/// Derivative of Tanh: f'(z) = 1 - tanh^2(z).
pub fn tanh_prime(z: &Vector) -> Vector {
let t = tanh(z);
t.map(|v| 1.0 - v * v)
}

/// GELU (Gaussian Error Linear Unit) activation function: f(z) = 0.5 * z * (1 + tanh(sqrt(2/pi) * (z + 0.044715 * z^3))).
/// Using the approximate formulation.
pub fn gelu(z: &Vector) -> Vector {
let sqrt_2_over_pi = (2.0f64 / std::f64::consts::PI).sqrt();
z.map(|v| 0.5 * v * (1.0 + (sqrt_2_over_pi * (v + 0.044715 * v.powi(3))).tanh()))
}

/// Derivative of GELU (approximate).
pub fn gelu_prime(z: &Vector) -> Vector {
let sqrt_2_over_pi = (2.0f64 / std::f64::consts::PI).sqrt();
z.map(|v| {
let x = sqrt_2_over_pi * (v + 0.044715 * v.powi(3));
let tanh_x = x.tanh();
let sech2_x = 1.0 - tanh_x * tanh_x;
0.5 * (1.0 + tanh_x)
+ 0.5 * v * sech2_x * sqrt_2_over_pi * (1.0 + 3.0 * 0.044715 * v.powi(2))
    })
}

Comment on lines +84 to +99 (Copilot AI, Mar 3, 2026):
gelu/gelu_prime repeat the same constants (0.044715 and sqrt_2_over_pi) in multiple places. Consider extracting these into named consts (e.g., GELU_COEFF, SQRT_2_OVER_PI) or a shared helper to avoid magic numbers and keep the approximation/derivative consistent if the coefficients ever need to change.

Owner reply: @copilot open a new pull request to apply changes based on this feedback
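A refactor along the lines of that feedback could look roughly like this. This is a sketch, not code from the PR: `GELU_COEFF`, `sqrt_2_over_pi`, and the scalar helpers are hypothetical names, and the formulas simply mirror the `gelu`/`gelu_prime` bodies above.

```rust
// Hypothetical sketch of the review's suggestion: one shared coefficient
// and one shared inner() helper, so gelu and gelu_prime cannot drift apart.
const GELU_COEFF: f64 = 0.044715;

fn sqrt_2_over_pi() -> f64 {
    (2.0_f64 / std::f64::consts::PI).sqrt()
}

/// Argument of the tanh in the approximate GELU: sqrt(2/pi) * (v + c * v^3).
fn gelu_inner(v: f64) -> f64 {
    sqrt_2_over_pi() * (v + GELU_COEFF * v.powi(3))
}

/// Scalar GELU, same approximate formulation as the vectorized version.
fn gelu_scalar(v: f64) -> f64 {
    0.5 * v * (1.0 + gelu_inner(v).tanh())
}

/// Scalar derivative of the approximate GELU.
fn gelu_prime_scalar(v: f64) -> f64 {
    let tanh_x = gelu_inner(v).tanh();
    let sech2_x = 1.0 - tanh_x * tanh_x; // sech^2(x) = 1 - tanh^2(x)
    0.5 * (1.0 + tanh_x)
        + 0.5 * v * sech2_x * sqrt_2_over_pi() * (1.0 + 3.0 * GELU_COEFF * v * v)
}
```

The vectorized functions would then reduce to `z.map(gelu_scalar)` and `z.map(gelu_prime_scalar)`, with the coefficient defined in exactly one place.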
134 changes: 134 additions & 0 deletions math_explorer_gui/src/tabs/ai/activation_functions.rs

@@ -0,0 +1,134 @@
use super::AiTool;
use eframe::egui;
use egui_plot::{Line, LineStyle, Plot, PlotPoints};
use math_explorer::ai::deep_learning_theory::calculus::{
gelu, gelu_prime, relu, relu_prime, sigmoid, sigmoid_prime, tanh, tanh_prime,
};
use nalgebra::DVector;

#[derive(PartialEq, Clone, Copy)]
enum ActivationFunction {
ReLU,
Sigmoid,
Tanh,
GELU,
}

impl ActivationFunction {
fn name(&self) -> &'static str {
match self {
ActivationFunction::ReLU => "ReLU",
ActivationFunction::Sigmoid => "Sigmoid",
ActivationFunction::Tanh => "Tanh",
ActivationFunction::GELU => "GELU",
}
}
}

pub struct ActivationFunctionsTool {
selected_function: ActivationFunction,
show_derivative: bool,
x_min: f64,
x_max: f64,
points_count: usize,
}

impl Default for ActivationFunctionsTool {
fn default() -> Self {
Self {
selected_function: ActivationFunction::ReLU,
show_derivative: false,
x_min: -5.0,
x_max: 5.0,
points_count: 500,
}
}
}

impl AiTool for ActivationFunctionsTool {
fn name(&self) -> &'static str {
"Activation Functions"
}

fn show(&mut self, ctx: &egui::Context) {
egui::SidePanel::left("activation_controls").show(ctx, |ui| {
ui.heading("Controls");
ui.separator();

ui.label("Select Function:");
ui.radio_value(
&mut self.selected_function,
ActivationFunction::ReLU,
"ReLU",
);
ui.radio_value(
&mut self.selected_function,
ActivationFunction::Sigmoid,
"Sigmoid",
);
ui.radio_value(
&mut self.selected_function,
ActivationFunction::Tanh,
"Tanh",
);
ui.radio_value(
&mut self.selected_function,
ActivationFunction::GELU,
"GELU",
);

ui.separator();
ui.checkbox(&mut self.show_derivative, "Show Derivative");

ui.separator();
ui.label("Domain:");
ui.add(egui::Slider::new(&mut self.x_min, -10.0..=-1.0).text("Min X"));
ui.add(egui::Slider::new(&mut self.x_max, 1.0..=10.0).text("Max X"));
Comment on lines +84 to +86 (Copilot AI, Mar 3, 2026):
The X-axis domain sliders are constrained to x_min ∈ [-10, -1] and x_max ∈ [1, 10], which prevents exploring purely-positive/negative domains and also prevents zooming in to narrow windows around 0 (e.g., [-0.5, 0.5]). Consider letting both sliders span a shared range (e.g., -10..=10) and enforcing x_min < x_max via clamping/swapping (and optionally showing a validation message) so the domain is truly adjustable.
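A minimal sketch of the normalization that comment proposes, under the assumption that both sliders are widened to a shared `-10.0..=10.0` range. The helper name `normalize_domain` is hypothetical and not part of this PR.

```rust
// Hypothetical fix per the review: let both sliders share one range and
// normalize the pair afterwards so x_min < x_max always holds.
fn normalize_domain(x_min: &mut f64, x_max: &mut f64) {
    if *x_min > *x_max {
        std::mem::swap(x_min, x_max);
    }
    // Enforce a minimal width so the later step computation
    // (x_max - x_min) / (points_count - 1) never degenerates to zero.
    if *x_max - *x_min < 1e-6 {
        *x_max = *x_min + 1e-6;
    }
}
```

Called right after the two `Slider` widgets, this keeps the domain valid on every frame without restricting which part of the range each slider can reach.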
});

egui::CentralPanel::default().show(ctx, |ui| {
ui.heading("Activation Function Plot");

// Generate data points
let step = (self.x_max - self.x_min) / (self.points_count as f64 - 1.0);
let x_vals: Vec<f64> = (0..self.points_count)
.map(|i| self.x_min + i as f64 * step)
.collect();
let x_vec = DVector::from_vec(x_vals.clone());

Comment on lines +93 to +98 (Copilot AI, Mar 3, 2026):
x_vals is cloned to build x_vec (DVector::from_vec(x_vals.clone())), which duplicates allocations/work every frame. Since x_vals is only used to build plot points, consider generating x coordinates once (e.g., directly as a DVector/iterator) and using that single source to both compute y-values and build PlotPoints, avoiding the extra clone.
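One way to apply that suggestion, sketched with a hypothetical `sample` helper and assuming scalar variants of the activation functions are available: each `[x, y]` pair is built in a single pass, so no intermediate `Vec` or `DVector` has to be cloned per frame.

```rust
// Hypothetical single-pass sampler per the review: the x grid is computed
// once, fed straight into the scalar activation, and collected directly
// into the [x, y] pairs that egui_plot's PlotPoints can be built from.
fn sample(f: impl Fn(f64) -> f64, x_min: f64, x_max: f64, n: usize) -> Vec<[f64; 2]> {
    let step = (x_max - x_min) / (n as f64 - 1.0);
    (0..n)
        .map(|i| {
            let x = x_min + i as f64 * step;
            [x, f(x)]
        })
        .collect()
}
```

The function curve and its derivative would then come from two `sample` calls (one per scalar function), removing both the `x_vals.clone()` and the zip-and-collect steps below.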
let (y_vals, y_prime_vals) = match self.selected_function {
ActivationFunction::ReLU => (relu(&x_vec), relu_prime(&x_vec)),
ActivationFunction::Sigmoid => (sigmoid(&x_vec), sigmoid_prime(&x_vec)),
ActivationFunction::Tanh => (tanh(&x_vec), tanh_prime(&x_vec)),
ActivationFunction::GELU => (gelu(&x_vec), gelu_prime(&x_vec)),
};

let points: PlotPoints = x_vals
.iter()
.zip(y_vals.iter())
.map(|(&x, &y)| [x, y])
.collect();
let line = Line::new(self.selected_function.name(), points);

let plot = Plot::new("activation_plot")
.legend(egui_plot::Legend::default())
.view_aspect(2.0);

plot.show(ui, |plot_ui| {
plot_ui.line(line);

if self.show_derivative {
let prime_points: PlotPoints = x_vals
.iter()
.zip(y_prime_vals.iter())
.map(|(&x, &y)| [x, y])
.collect();
let prime_line =
Line::new(format!("{}'", self.selected_function.name()), prime_points)
.style(LineStyle::Dashed { length: 5.0 });
plot_ui.line(prime_line);
}
});
});
}
}
2 changes: 2 additions & 0 deletions math_explorer_gui/src/tabs/ai/mod.rs

@@ -1,6 +1,7 @@
use crate::tabs::ExplorerTab;
use eframe::egui;

pub mod activation_functions;
pub mod loss_landscape;
pub mod training_monitor;

@@ -24,6 +25,7 @@ impl Default for AiTab {
tools: vec![
Box::new(loss_landscape::LossLandscapeTool::default()),
Box::new(training_monitor::TrainingMonitorTool::default()),
Box::new(activation_functions::ActivationFunctionsTool::default()),
],
selected_tool_index: 0,
}
7 changes: 6 additions & 1 deletion math_explorer_gui/src/tabs/number_theory/prime_spiral.rs

@@ -62,7 +62,12 @@ impl PrimeSpiralWidget {
is_prime_lookup[p as usize] = true;
}

for (i, &is_prime) in is_prime_lookup.iter().enumerate().take(num_pixels + 1).skip(1) {
for (i, &is_prime) in is_prime_lookup
.iter()
.enumerate()
.take(num_pixels + 1)
.skip(1)
{
// Fix unary negation on usize by casting to i32 first
if (-(size as i32) / 2 <= x)
&& (x <= size as i32 / 2)
2 changes: 1 addition & 1 deletion todo_gui.md

@@ -84,7 +84,7 @@ This document outlines the roadmap for integrating the various modules of the `m
* **Features:**
* [x] **Loss Landscape:** 3D surface plot of loss functions.
* [x] **Training Monitor:** Real-time curves for training/validation loss and accuracy.
* [ ] **Activation Functions:** Interactive plotter for ReLU, Sigmoid, Tanh, GELU, etc.
* [x] **Activation Functions:** Interactive plotter for ReLU, Sigmoid, Tanh, GELU, etc.

### 3.2 Reinforcement Learning
* **Module:** `ai::reinforcement_learning`