-
Notifications
You must be signed in to change notification settings - Fork 0
feat: Implement Activation Functions Plotter in AI Tab #646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| use crate::tabs::ai::AiTool; | ||
| use eframe::egui; | ||
| use egui_plot::{Legend, Line, Plot, PlotPoints}; | ||
| use math_explorer::ai::deep_learning_theory::calculus::{ | ||
| gelu, gelu_prime, relu, relu_prime, sigmoid, sigmoid_prime, tanh, tanh_prime, | ||
| }; | ||
| use math_explorer::ai::deep_learning_theory::linear_algebra::Vector; | ||
|
|
||
/// Activation functions selectable in the plotter UI.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ActivationFunction {
    /// Rectified Linear Unit: max(0, x).
    ReLU,
    /// Logistic function: 1 / (1 + e^(-x)).
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Gaussian Error Linear Unit (tanh approximation — see the Math Info panel).
    Gelu,
}
|
|
||
/// AI-tab tool that plots an activation function f(x) together with its
/// derivative f'(x) over a user-adjustable x range.
pub struct ActivationFunctionsTool {
    /// Which activation function is currently plotted.
    selected_function: ActivationFunction,
    /// Lower bound of the plotted x range.
    x_min: f64,
    /// Upper bound of the plotted x range.
    x_max: f64,
    /// Number of sample points taken across [x_min, x_max].
    points: usize,
}
|
|
||
| impl Default for ActivationFunctionsTool { | ||
| fn default() -> Self { | ||
| Self { | ||
| selected_function: ActivationFunction::ReLU, | ||
| x_min: -5.0, | ||
| x_max: 5.0, | ||
| points: 200, | ||
| } | ||
| } | ||
| } | ||
|
|
||
| impl AiTool for ActivationFunctionsTool { | ||
| fn name(&self) -> &'static str { | ||
| "Activation Functions" | ||
| } | ||
|
|
||
| fn show(&mut self, ctx: &egui::Context) { | ||
| egui::SidePanel::right("activation_functions_controls").show(ctx, |ui| { | ||
| ui.heading("Controls"); | ||
| ui.separator(); | ||
|
|
||
| ui.label("Select Activation Function:"); | ||
| ui.radio_value( | ||
| &mut self.selected_function, | ||
| ActivationFunction::ReLU, | ||
| "ReLU", | ||
| ); | ||
| ui.radio_value( | ||
| &mut self.selected_function, | ||
| ActivationFunction::Sigmoid, | ||
| "Sigmoid", | ||
| ); | ||
| ui.radio_value( | ||
| &mut self.selected_function, | ||
| ActivationFunction::Tanh, | ||
| "Tanh", | ||
| ); | ||
| ui.radio_value( | ||
| &mut self.selected_function, | ||
| ActivationFunction::Gelu, | ||
| "GELU", | ||
| ); | ||
|
|
||
| ui.separator(); | ||
| ui.label("Plot Range"); | ||
| ui.horizontal(|ui| { | ||
| ui.label("X Min:"); | ||
| ui.add(egui::DragValue::new(&mut self.x_min).speed(0.1)); | ||
| }); | ||
| ui.horizontal(|ui| { | ||
| ui.label("X Max:"); | ||
| ui.add(egui::DragValue::new(&mut self.x_max).speed(0.1)); | ||
| }); | ||
|
|
||
| // Ensure min < max | ||
| if self.x_min >= self.x_max { | ||
| self.x_max = self.x_min + 1.0; | ||
| } | ||
|
|
||
| ui.separator(); | ||
| ui.label("Math Info"); | ||
| match self.selected_function { | ||
| ActivationFunction::ReLU => { | ||
| ui.label("f(x) = max(0, x)"); | ||
| ui.label("f'(x) = 1 if x > 0 else 0"); | ||
| } | ||
| ActivationFunction::Sigmoid => { | ||
| ui.label("f(x) = 1 / (1 + e^(-x))"); | ||
| ui.label("f'(x) = f(x) * (1 - f(x))"); | ||
| } | ||
| ActivationFunction::Tanh => { | ||
| ui.label("f(x) = tanh(x)"); | ||
| ui.label("f'(x) = 1 - tanh^2(x)"); | ||
| } | ||
| ActivationFunction::Gelu => { | ||
| ui.label("f(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))"); | ||
| } | ||
| } | ||
| }); | ||
|
|
||
| egui::CentralPanel::default().show(ctx, |ui| { | ||
| // Generate x values | ||
| let step = (self.x_max - self.x_min) / (self.points as f64 - 1.0); | ||
| let mut x_vals = Vec::with_capacity(self.points); | ||
| for i in 0..self.points { | ||
| x_vals.push(self.x_min + (i as f64) * step); | ||
| } | ||
|
|
||
| let x_vec = Vector::from_vec(x_vals.clone()); | ||
|
|
||
| // Compute y values and derivatives | ||
| let (y_vec, dy_vec) = match self.selected_function { | ||
| ActivationFunction::ReLU => (relu(&x_vec), relu_prime(&x_vec)), | ||
| ActivationFunction::Sigmoid => (sigmoid(&x_vec), sigmoid_prime(&x_vec)), | ||
| ActivationFunction::Tanh => (tanh(&x_vec), tanh_prime(&x_vec)), | ||
| ActivationFunction::Gelu => (gelu(&x_vec), gelu_prime(&x_vec)), | ||
| }; | ||
|
|
||
| // Map to egui_plot points | ||
| let plot_points: PlotPoints = x_vals | ||
| .iter() | ||
| .zip(y_vec.iter()) | ||
| .map(|(x, y)| [*x, *y]) | ||
| .collect(); | ||
|
|
||
| let deriv_points: PlotPoints = x_vals | ||
| .iter() | ||
| .zip(dy_vec.iter()) | ||
| .map(|(x, y)| [*x, *y]) | ||
| .collect(); | ||
|
|
||
| let line = Line::new("f(x)", plot_points).width(2.0); | ||
| let deriv_line = Line::new("f'(x)", deriv_points).width(1.5); | ||
|
|
||
| Plot::new("activation_function_plot") | ||
| .legend(Legend::default()) | ||
| .show(ui, |plot_ui| { | ||
| plot_ui.line(line); | ||
| plot_ui.line(deriv_line); | ||
| }); | ||
| }); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -39,8 +39,12 @@ impl Default for NeuralNetworkVizTool { | |||||||||||||||||
| let next = (i + 1) % num_neurons; | ||||||||||||||||||
| let prev = (i + num_neurons - 1) % num_neurons; | ||||||||||||||||||
|
|
||||||||||||||||||
| weights[next][i] = 2.0; // Excitatory connection | ||||||||||||||||||
| weights[prev][i] = -1.0; // Inhibitory connection | ||||||||||||||||||
| if let Some(row) = weights.get_mut(next) { | ||||||||||||||||||
| row[i] = 2.0; // Excitatory connection | ||||||||||||||||||
| } | ||||||||||||||||||
| if let Some(row) = weights.get_mut(prev) { | ||||||||||||||||||
| row[i] = -1.0; // Inhibitory connection | ||||||||||||||||||
| } | ||||||||||||||||||
|
Comment on lines
+42
to
+47
|
||||||||||||||||||
| if let Some(row) = weights.get_mut(next) { | |
| row[i] = 2.0; // Excitatory connection | |
| } | |
| if let Some(row) = weights.get_mut(prev) { | |
| row[i] = -1.0; // Inhibitory connection | |
| } | |
| weights[next][i] = 2.0; // Excitatory connection | |
| weights[prev][i] = -1.0; // Inhibitory connection |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`step` is computed as `(x_max - x_min) / (points - 1)`. If `points` is ever set to 0 or 1, this will produce an invalid step (division by zero / infinities) and can yield NaNs in the plotted data. Consider clamping `points` to at least 2 (or handling the `points <= 1` case explicitly) before computing `step`.