Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions math_explorer/src/epidemiology/networks.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use rand::Rng;
use std::f32::consts::TAU;

/// Calculates R0 for a heterogeneous network.
///
/// $R_0 = \frac{\beta}{\gamma} \frac{\langle k^2 \rangle - \langle k \rangle}{\langle k \rangle}$
Expand All @@ -15,6 +18,101 @@ pub fn heterogeneous_r0(beta: f64, gamma: f64, mean_degree: f64, degree_variance
(beta / gamma) * factor
}

#[derive(Clone, Copy, PartialEq, Debug)]
pub enum NodeState {
Susceptible,
Infected,
Recovered,
}

pub struct NetworkEpidemicModel {
pub num_nodes: usize,
pub states: Vec<NodeState>,
pub positions: Vec<[f32; 2]>,
pub adjacency: Vec<Vec<usize>>,
pub beta: f64,
pub gamma: f64,
}
Comment on lines +28 to +35
Comment on lines +21 to +35

impl NetworkEpidemicModel {
pub fn new(num_nodes: usize, beta: f64, gamma: f64) -> Self {
let mut model = Self {
num_nodes,
states: vec![NodeState::Susceptible; num_nodes],
positions: vec![[0.0, 0.0]; num_nodes],
adjacency: vec![vec![]; num_nodes],
beta,
gamma,
};
model.initialize_geometric_graph();
model
}
Comment on lines +37 to +49

pub fn initialize_geometric_graph(&mut self) {
let mut rng = rand::thread_rng();
self.states = vec![NodeState::Susceptible; self.num_nodes];
self.positions = vec![[0.0, 0.0]; self.num_nodes];
self.adjacency = vec![vec![]; self.num_nodes];

// Random geometric graph
let radius = 200.0;
for i in 0..self.num_nodes {
let angle = rng.r#gen_range(0.0..TAU);
let r = radius * rng.r#gen_range(0.0f32..1.0f32).sqrt();
self.positions[i] = [r * angle.cos(), r * angle.sin()];
}

let connection_radius = 60.0;
for i in 0..self.num_nodes {
for j in (i + 1)..self.num_nodes {
let dx = self.positions[i][0] - self.positions[j][0];
let dy = self.positions[i][1] - self.positions[j][1];
let dist = (dx * dx + dy * dy).sqrt();
if dist < connection_radius {
self.adjacency[i].push(j);
self.adjacency[j].push(i);
}
}
}
Comment on lines +65 to +76

// Start with one infected
if self.num_nodes > 0 {
let start_idx = rng.r#gen_range(0..self.num_nodes);
self.states[start_idx] = NodeState::Infected;
}
}

pub fn step(&mut self) {
let mut next_states = self.states.clone();
let mut rng = rand::thread_rng();

for (i, next_state) in next_states.iter_mut().enumerate().take(self.num_nodes) {
match self.states[i] {
NodeState::Susceptible => {
// Check infected neighbors
let infected_neighbors = self.adjacency[i]
.iter()
.filter(|&&j| self.states[j] == NodeState::Infected)
.count();
for _ in 0..infected_neighbors {
if rng.r#gen::<f64>() < self.beta {
*next_state = NodeState::Infected;
break;
}
}
}
NodeState::Infected => {
if rng.r#gen::<f64>() < self.gamma {
*next_state = NodeState::Recovered;
}
}
NodeState::Recovered => {}
}
}
self.states = next_states;
}
Comment on lines +51 to +113
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
31 changes: 18 additions & 13 deletions math_explorer_gui/src/tabs/ai/attention_maps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ impl Default for AttentionMapsTool {

impl AttentionMapsTool {
fn recalculate(&mut self) {
let (output, weights) = math_explorer::ai::transformer::attention::scaled_dot_product_attention(
&self.q_matrix,
&self.k_matrix,
&self.v_matrix,
None,
);
let (output, weights) =
math_explorer::ai::transformer::attention::scaled_dot_product_attention(
&self.q_matrix,
&self.k_matrix,
&self.v_matrix,
None,
);
self.output_matrix = output;
self.attention_weights = weights;
}
Expand Down Expand Up @@ -93,10 +94,8 @@ impl AttentionMapsTool {
((1.0 - normalized) * 200.0 + 55.0) as u8,
);

let (rect, _response) = ui.allocate_exact_size(
egui::vec2(40.0, 40.0),
egui::Sense::hover(),
);
let (rect, _response) =
ui.allocate_exact_size(egui::vec2(40.0, 40.0), egui::Sense::hover());

ui.painter().rect_filled(rect, 2.0, color);

Expand Down Expand Up @@ -135,13 +134,15 @@ impl AiTool for AttentionMapsTool {
ui.vertical(|ui| {
let mut changed = false;
ui.group(|ui| {
changed |= Self::draw_matrix_input(ui, "Queries (Q)", &mut self.q_matrix);
changed |=
Self::draw_matrix_input(ui, "Queries (Q)", &mut self.q_matrix);
});
ui.group(|ui| {
changed |= Self::draw_matrix_input(ui, "Keys (K)", &mut self.k_matrix);
});
ui.group(|ui| {
changed |= Self::draw_matrix_input(ui, "Values (V)", &mut self.v_matrix);
changed |=
Self::draw_matrix_input(ui, "Values (V)", &mut self.v_matrix);
});

if changed {
Expand All @@ -153,7 +154,11 @@ impl AiTool for AttentionMapsTool {

ui.vertical(|ui| {
ui.group(|ui| {
Self::draw_heatmap(ui, "Attention Weights (softmax(Q * K^T / sqrt(d_k)))", &self.attention_weights);
Self::draw_heatmap(
ui,
"Attention Weights (softmax(Q * K^T / sqrt(d_k)))",
&self.attention_weights,
);
});

ui.add_space(20.0);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use super::BatteryDegradationTool;
use eframe::egui;
use math_explorer::applied::battery_degradation::{
Capacity, DepthOfDischarge, PowerLawModel,
};
use math_explorer::applied::battery_degradation::{Capacity, DepthOfDischarge, PowerLawModel};

pub struct LifetimeEstimatorTool {
target_capacity: f64,
Expand Down
7 changes: 6 additions & 1 deletion math_explorer_gui/src/tabs/epidemiology/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use crate::tabs::ExplorerTab;
use eframe::egui;

pub mod network_propagation;
pub mod sir;

use network_propagation::NetworkPropagationTool;
use sir::SirTool;

/// A trait for sub-tools within the Epidemiology tab.
Expand All @@ -22,7 +24,10 @@ pub struct EpidemiologyTab {
impl Default for EpidemiologyTab {
fn default() -> Self {
Self {
tools: vec![Box::new(SirTool::default())],
tools: vec![
Box::new(SirTool::default()),
Box::new(NetworkPropagationTool::default()),
],
selected_tool_index: 0,
}
}
Expand Down
141 changes: 141 additions & 0 deletions math_explorer_gui/src/tabs/epidemiology/network_propagation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use super::EpidemiologyTool;
use eframe::egui;
use egui::Color32;
use math_explorer::epidemiology::networks::{NetworkEpidemicModel, NodeState};

pub struct NetworkPropagationTool {
model: NetworkEpidemicModel,
is_running: bool,
}

impl Default for NetworkPropagationTool {
fn default() -> Self {
Self {
model: NetworkEpidemicModel::new(50, 0.05, 0.02),
is_running: false,
}
}
}

impl NetworkPropagationTool {
fn reset_network(&mut self) {
self.model.initialize_geometric_graph();
self.is_running = false;
}
}

impl EpidemiologyTool for NetworkPropagationTool {
fn name(&self) -> &'static str {
"Network Propagation"
}

fn show(&mut self, ui: &mut egui::Ui) {
if self.is_running {
self.model.step();
ui.ctx().request_repaint();
}

ui.horizontal(|ui| {
ui.vertical(|ui| {
ui.heading("Controls");
ui.separator();

if ui
.button(if self.is_running { "Pause" } else { "Start" })
.clicked()
{
self.is_running = !self.is_running;
}
if ui.button("Reset Network").clicked() {
self.reset_network();
}

ui.separator();
let mut changed = false;
changed |= ui
.add(egui::Slider::new(&mut self.model.num_nodes, 10..=200).text("Nodes"))
.changed();
if changed {
self.reset_network();
}
ui.add(
egui::Slider::new(&mut self.model.beta, 0.0..=1.0).text("Transmission (beta)"),
);
ui.add(
egui::Slider::new(&mut self.model.gamma, 0.0..=1.0).text("Recovery (gamma)"),
);

ui.separator();
// Statistics
let s_count = self
.model
.states
.iter()
.filter(|&&s| s == NodeState::Susceptible)
.count();
let i_count = self
.model
.states
.iter()
.filter(|&&s| s == NodeState::Infected)
.count();
let r_count = self
.model
.states
.iter()
.filter(|&&s| s == NodeState::Recovered)
.count();

ui.label(
egui::RichText::new(format!("Susceptible: {}", s_count)).color(Color32::BLUE),
);
ui.label(egui::RichText::new(format!("Infected: {}", i_count)).color(Color32::RED));
ui.label(
egui::RichText::new(format!("Recovered: {}", r_count)).color(Color32::GREEN),
);
});

ui.separator();

ui.vertical(|ui| {
ui.heading("Network Visualization");
let (response, painter) =
ui.allocate_painter(ui.available_size(), egui::Sense::hover());
let rect = response.rect;
let center = rect.center();

// Draw edges
for i in 0..self.model.num_nodes {
for &j in &self.model.adjacency[i] {
if i < j {
let p1 = center
+ egui::vec2(
self.model.positions[i][0],
self.model.positions[i][1],
);
let p2 = center
+ egui::vec2(
self.model.positions[j][0],
self.model.positions[j][1],
);
painter.line_segment([p1, p2], (1.0, Color32::from_gray(100)));
}
}
}

// Draw nodes
for i in 0..self.model.num_nodes {
let p =
center + egui::vec2(self.model.positions[i][0], self.model.positions[i][1]);
let color = match self.model.states[i] {
NodeState::Susceptible => Color32::BLUE,
NodeState::Infected => Color32::RED,
NodeState::Recovered => Color32::GREEN,
};
painter.circle_filled(p, 5.0, color);
painter.circle_stroke(p, 5.0, (1.0, Color32::WHITE));
}
});
});
}
}
2 changes: 1 addition & 1 deletion todo_gui.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ This document outlines the roadmap for integrating the various modules of the `m
* **Features:**
* [x] **SIR/SEIR Models:** Time-series plots of Susceptible, Infected, Recovered populations.
* [x] **Parameter Sliders:** Adjust transmission rate ($\beta$) and recovery rate ($\gamma$).
* [ ] **Network Propagation:** Graph visualization of disease spread through a population.
* [x] **Network Propagation:** Graph visualization of disease spread through a population.

### 2.3 Evolutionary Game Theory
* **Module:** `applied::game_theory::evolutionary`
Expand Down
Loading