diff --git a/math_explorer/src/epidemiology/networks.rs b/math_explorer/src/epidemiology/networks.rs index 3f09ad7c..ed442f77 100644 --- a/math_explorer/src/epidemiology/networks.rs +++ b/math_explorer/src/epidemiology/networks.rs @@ -1,3 +1,6 @@ +use rand::Rng; +use std::f32::consts::TAU; + /// Calculates R0 for a heterogeneous network. /// /// $R_0 = \frac{\beta}{\gamma} \frac{\langle k^2 \rangle - \langle k \rangle}{\langle k \rangle}$ @@ -15,6 +18,101 @@ pub fn heterogeneous_r0(beta: f64, gamma: f64, mean_degree: f64, degree_variance (beta / gamma) * factor } +#[derive(Clone, Copy, PartialEq, Debug)] +pub enum NodeState { + Susceptible, + Infected, + Recovered, +} + +pub struct NetworkEpidemicModel { + pub num_nodes: usize, + pub states: Vec, + pub positions: Vec<[f32; 2]>, + pub adjacency: Vec>, + pub beta: f64, + pub gamma: f64, +} + +impl NetworkEpidemicModel { + pub fn new(num_nodes: usize, beta: f64, gamma: f64) -> Self { + let mut model = Self { + num_nodes, + states: vec![NodeState::Susceptible; num_nodes], + positions: vec![[0.0, 0.0]; num_nodes], + adjacency: vec![vec![]; num_nodes], + beta, + gamma, + }; + model.initialize_geometric_graph(); + model + } + + pub fn initialize_geometric_graph(&mut self) { + let mut rng = rand::thread_rng(); + self.states = vec![NodeState::Susceptible; self.num_nodes]; + self.positions = vec![[0.0, 0.0]; self.num_nodes]; + self.adjacency = vec![vec![]; self.num_nodes]; + + // Random geometric graph + let radius = 200.0; + for i in 0..self.num_nodes { + let angle = rng.r#gen_range(0.0..TAU); + let r = radius * rng.r#gen_range(0.0f32..1.0f32).sqrt(); + self.positions[i] = [r * angle.cos(), r * angle.sin()]; + } + + let connection_radius = 60.0; + for i in 0..self.num_nodes { + for j in (i + 1)..self.num_nodes { + let dx = self.positions[i][0] - self.positions[j][0]; + let dy = self.positions[i][1] - self.positions[j][1]; + let dist = (dx * dx + dy * dy).sqrt(); + if dist < connection_radius { + self.adjacency[i].push(j); + self.adjacency[j].push(i); + } + } + } + + // Start with one infected + if self.num_nodes > 0 { + let start_idx = rng.r#gen_range(0..self.num_nodes); + self.states[start_idx] = NodeState::Infected; + } + } + + pub fn step(&mut self) { + let mut next_states = self.states.clone(); + let mut rng = rand::thread_rng(); + + for (i, next_state) in next_states.iter_mut().enumerate().take(self.num_nodes) { + match self.states[i] { + NodeState::Susceptible => { + // Check infected neighbors + let infected_neighbors = self.adjacency[i] + .iter() + .filter(|&&j| self.states[j] == NodeState::Infected) + .count(); + for _ in 0..infected_neighbors { + if rng.r#gen::() < self.beta { + *next_state = NodeState::Infected; + break; + } + } + } + NodeState::Infected => { + if rng.r#gen::() < self.gamma { + *next_state = NodeState::Recovered; + } + } + NodeState::Recovered => {} + } + } + self.states = next_states; + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/math_explorer_gui/src/tabs/ai/attention_maps.rs b/math_explorer_gui/src/tabs/ai/attention_maps.rs index e64210c1..9a5be3ef 100644 --- a/math_explorer_gui/src/tabs/ai/attention_maps.rs +++ b/math_explorer_gui/src/tabs/ai/attention_maps.rs @@ -29,12 +29,13 @@ impl Default for AttentionMapsTool { impl AttentionMapsTool { fn recalculate(&mut self) { - let (output, weights) = math_explorer::ai::transformer::attention::scaled_dot_product_attention( - &self.q_matrix, - &self.k_matrix, - &self.v_matrix, - None, - ); + let (output, weights) = + math_explorer::ai::transformer::attention::scaled_dot_product_attention( + &self.q_matrix, + &self.k_matrix, + &self.v_matrix, + None, + ); self.output_matrix = output; self.attention_weights = weights; } @@ -93,10 +94,8 @@ impl AttentionMapsTool { ((1.0 - normalized) * 200.0 + 55.0) as u8, ); - let (rect, _response) = ui.allocate_exact_size( - egui::vec2(40.0, 40.0), - egui::Sense::hover(), - ); + let (rect, _response) = + ui.allocate_exact_size(egui::vec2(40.0, 40.0), egui::Sense::hover()); ui.painter().rect_filled(rect, 2.0, color); @@ -135,13 +134,15 @@ impl AiTool for AttentionMapsTool { ui.vertical(|ui| { let mut changed = false; ui.group(|ui| { - changed |= Self::draw_matrix_input(ui, "Queries (Q)", &mut self.q_matrix); + changed |= + Self::draw_matrix_input(ui, "Queries (Q)", &mut self.q_matrix); }); ui.group(|ui| { changed |= Self::draw_matrix_input(ui, "Keys (K)", &mut self.k_matrix); }); ui.group(|ui| { - changed |= Self::draw_matrix_input(ui, "Values (V)", &mut self.v_matrix); + changed |= + Self::draw_matrix_input(ui, "Values (V)", &mut self.v_matrix); }); if changed { @@ -153,7 +154,11 @@ impl AiTool for AttentionMapsTool { ui.vertical(|ui| { ui.group(|ui| { - Self::draw_heatmap(ui, "Attention Weights (softmax(Q * K^T / sqrt(d_k)))", &self.attention_weights); + Self::draw_heatmap( + ui, + "Attention Weights (softmax(Q * K^T / sqrt(d_k)))", + &self.attention_weights, + ); }); ui.add_space(20.0); diff --git a/math_explorer_gui/src/tabs/battery_degradation/lifetime_estimator.rs b/math_explorer_gui/src/tabs/battery_degradation/lifetime_estimator.rs index 0583a6da..33f134ca 100644 --- a/math_explorer_gui/src/tabs/battery_degradation/lifetime_estimator.rs +++ b/math_explorer_gui/src/tabs/battery_degradation/lifetime_estimator.rs @@ -1,8 +1,6 @@ use super::BatteryDegradationTool; use eframe::egui; -use math_explorer::applied::battery_degradation::{ - Capacity, DepthOfDischarge, PowerLawModel, -}; +use math_explorer::applied::battery_degradation::{Capacity, DepthOfDischarge, PowerLawModel}; pub struct LifetimeEstimatorTool { target_capacity: f64, diff --git a/math_explorer_gui/src/tabs/epidemiology/mod.rs b/math_explorer_gui/src/tabs/epidemiology/mod.rs index 93e51dcf..e77e9c78 100644 --- a/math_explorer_gui/src/tabs/epidemiology/mod.rs +++ b/math_explorer_gui/src/tabs/epidemiology/mod.rs @@ -1,8 +1,10 @@ use crate::tabs::ExplorerTab; use eframe::egui; +pub mod network_propagation; pub mod sir; +use network_propagation::NetworkPropagationTool; use sir::SirTool; /// A trait for sub-tools within the Epidemiology tab. @@ -22,7 +24,10 @@ pub struct EpidemiologyTab { impl Default for EpidemiologyTab { fn default() -> Self { Self { - tools: vec![Box::new(SirTool::default())], + tools: vec![ + Box::new(SirTool::default()), + Box::new(NetworkPropagationTool::default()), + ], selected_tool_index: 0, } } diff --git a/math_explorer_gui/src/tabs/epidemiology/network_propagation.rs b/math_explorer_gui/src/tabs/epidemiology/network_propagation.rs new file mode 100644 index 00000000..0a2da12c --- /dev/null +++ b/math_explorer_gui/src/tabs/epidemiology/network_propagation.rs @@ -0,0 +1,141 @@ +use super::EpidemiologyTool; +use eframe::egui; +use egui::Color32; +use math_explorer::epidemiology::networks::{NetworkEpidemicModel, NodeState}; + +pub struct NetworkPropagationTool { + model: NetworkEpidemicModel, + is_running: bool, +} + +impl Default for NetworkPropagationTool { + fn default() -> Self { + Self { + model: NetworkEpidemicModel::new(50, 0.05, 0.02), + is_running: false, + } + } +} + +impl NetworkPropagationTool { + fn reset_network(&mut self) { + self.model.initialize_geometric_graph(); + self.is_running = false; + } +} + +impl EpidemiologyTool for NetworkPropagationTool { + fn name(&self) -> &'static str { + "Network Propagation" + } + + fn show(&mut self, ui: &mut egui::Ui) { + if self.is_running { + self.model.step(); + ui.ctx().request_repaint(); + } + + ui.horizontal(|ui| { + ui.vertical(|ui| { + ui.heading("Controls"); + ui.separator(); + + if ui + .button(if self.is_running { "Pause" } else { "Start" }) + .clicked() + { + self.is_running = !self.is_running; + } + if ui.button("Reset Network").clicked() { + self.reset_network(); + } + + ui.separator(); + let mut changed = false; + changed |= ui + .add(egui::Slider::new(&mut self.model.num_nodes, 10..=200).text("Nodes")) + .changed(); + if changed { + self.reset_network(); + } + ui.add( + egui::Slider::new(&mut self.model.beta, 0.0..=1.0).text("Transmission (beta)"), + ); + ui.add( + egui::Slider::new(&mut self.model.gamma, 0.0..=1.0).text("Recovery (gamma)"), + ); + + ui.separator(); + // Statistics + let s_count = self + .model + .states + .iter() + .filter(|&&s| s == NodeState::Susceptible) + .count(); + let i_count = self + .model + .states + .iter() + .filter(|&&s| s == NodeState::Infected) + .count(); + let r_count = self + .model + .states + .iter() + .filter(|&&s| s == NodeState::Recovered) + .count(); + + ui.label( + egui::RichText::new(format!("Susceptible: {}", s_count)).color(Color32::BLUE), + ); + ui.label(egui::RichText::new(format!("Infected: {}", i_count)).color(Color32::RED)); + ui.label( + egui::RichText::new(format!("Recovered: {}", r_count)).color(Color32::GREEN), + ); + }); + + ui.separator(); + + ui.vertical(|ui| { + ui.heading("Network Visualization"); + let (response, painter) = + ui.allocate_painter(ui.available_size(), egui::Sense::hover()); + let rect = response.rect; + let center = rect.center(); + + // Draw edges + for i in 0..self.model.num_nodes { + for &j in &self.model.adjacency[i] { + if i < j { + let p1 = center + + egui::vec2( + self.model.positions[i][0], + self.model.positions[i][1], + ); + let p2 = center + + egui::vec2( + self.model.positions[j][0], + self.model.positions[j][1], + ); + painter.line_segment([p1, p2], (1.0, Color32::from_gray(100))); + } + } + } + + // Draw nodes + for i in 0..self.model.num_nodes { + let p = + center + egui::vec2(self.model.positions[i][0], self.model.positions[i][1]); + let color = match self.model.states[i] { + NodeState::Susceptible => Color32::BLUE, + NodeState::Infected => Color32::RED, + NodeState::Recovered => Color32::GREEN, + }; + painter.circle_filled(p, 5.0, color); + painter.circle_stroke(p, 5.0, (1.0, Color32::WHITE)); + } + }); + }); + } +} diff --git a/todo_gui.md b/todo_gui.md index 122c3e52..31ffa8a8 100644 --- a/todo_gui.md +++ b/todo_gui.md @@ -61,7 +61,7 @@ This document outlines the roadmap for integrating the various modules of the `m * **Features:** * [x] **SIR/SEIR Models:** Time-series plots of Susceptible, Infected, Recovered populations. * [x] **Parameter Sliders:** Adjust transmission rate ($\beta$) and recovery rate ($\gamma$). - * [ ] **Network Propagation:** Graph visualization of disease spread through a population. + * [x] **Network Propagation:** Graph visualization of disease spread through a population. ### 2.3 Evolutionary Game Theory * **Module:** `applied::game_theory::evolutionary`