Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions math_explorer_gui/src/tabs/ai/attention_maps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ impl Default for AttentionMapsTool {

impl AttentionMapsTool {
fn recalculate(&mut self) {
    // Recompute the attention output and per-position weights from the
    // current Q/K/V matrices. Called whenever any input matrix changes.
    let (output, weights) =
        math_explorer::ai::transformer::attention::scaled_dot_product_attention(
            &self.q_matrix,
            &self.k_matrix,
            &self.v_matrix,
            None, // no attention mask
        );
    self.output_matrix = output;
    self.attention_weights = weights;
}
Expand Down Expand Up @@ -93,10 +94,8 @@ impl AttentionMapsTool {
((1.0 - normalized) * 200.0 + 55.0) as u8,
);

let (rect, _response) = ui.allocate_exact_size(
egui::vec2(40.0, 40.0),
egui::Sense::hover(),
);
let (rect, _response) =
ui.allocate_exact_size(egui::vec2(40.0, 40.0), egui::Sense::hover());

ui.painter().rect_filled(rect, 2.0, color);

Expand Down Expand Up @@ -135,13 +134,15 @@ impl AiTool for AttentionMapsTool {
ui.vertical(|ui| {
let mut changed = false;
ui.group(|ui| {
changed |= Self::draw_matrix_input(ui, "Queries (Q)", &mut self.q_matrix);
changed |=
Self::draw_matrix_input(ui, "Queries (Q)", &mut self.q_matrix);
});
ui.group(|ui| {
changed |= Self::draw_matrix_input(ui, "Keys (K)", &mut self.k_matrix);
});
ui.group(|ui| {
changed |= Self::draw_matrix_input(ui, "Values (V)", &mut self.v_matrix);
changed |=
Self::draw_matrix_input(ui, "Values (V)", &mut self.v_matrix);
});

if changed {
Expand All @@ -153,7 +154,11 @@ impl AiTool for AttentionMapsTool {

ui.vertical(|ui| {
ui.group(|ui| {
Self::draw_heatmap(ui, "Attention Weights (softmax(Q * K^T / sqrt(d_k)))", &self.attention_weights);
Self::draw_heatmap(
ui,
"Attention Weights (softmax(Q * K^T / sqrt(d_k)))",
&self.attention_weights,
);
});

ui.add_space(20.0);
Expand Down
261 changes: 261 additions & 0 deletions math_explorer_gui/src/tabs/ai/grid_world.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
use crate::tabs::ai::AiTool;
use eframe::egui;
use math_explorer::ai::reinforcement_learning::{
algorithms::TabularQAgent, Action, MarkovDecisionProcess, State,
};
use std::hash::Hash;

/// A cell coordinate in the grid world. `(0, 0)` is the top-left cell;
/// `y` grows downward (see the drawing code, which maps `y` to vertical
/// screen offset).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct GridState {
    // Column index; valid cells satisfy 0 <= x < width.
    pub x: i32,
    // Row index; valid cells satisfy 0 <= y < height.
    pub y: i32,
}

// Marker impl so `GridState` can serve as the MDP's state type.
impl State for GridState {}

/// The four cardinal moves available to the agent. `Up` decreases `y`,
/// `Down` increases it; `Left`/`Right` adjust `x` likewise.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Move {
    Up,
    Down,
    Left,
    Right,
}

// Marker impl so `Move` can serve as the MDP's action type.
impl Action for Move {}

/// A rectangular grid-world environment with a single goal cell, a start
/// cell, and a set of trap cells. Goal and traps are terminal states.
pub struct GridWorldEnv {
    // Number of columns.
    pub width: i32,
    // Number of rows.
    pub height: i32,
    // Terminal state with positive reward (+10.0 on entry).
    pub goal: GridState,
    // Where the agent is placed at the beginning of each episode.
    pub start: GridState,
    // Terminal states with negative reward (-10.0 on entry).
    pub traps: Vec<GridState>,
    // Discount factor returned by `discount_factor()`.
    pub gamma: f64,
}

impl MarkovDecisionProcess for GridWorldEnv {
type S = GridState;
type A = Move;

fn transition_probability(
&self,
next_state: &Self::S,
current_state: &Self::S,
action: &Self::A,
) -> f64 {
// Simplified deterministic transition for the tool
let mut expected_next = *current_state;
match action {
Move::Up => expected_next.y -= 1,
Move::Down => expected_next.y += 1,
Move::Left => expected_next.x -= 1,
Move::Right => expected_next.x += 1,
}

let is_valid = expected_next.x >= 0
&& expected_next.x < self.width
&& expected_next.y >= 0
&& expected_next.y < self.height;

let actual_next = if is_valid {
expected_next
} else {
*current_state
};

if *next_state == actual_next {
1.0
} else {
0.0
}
}

fn reward(&self, _current_state: &Self::S, _action: &Self::A, next_state: &Self::S) -> f64 {
if *next_state == self.goal {
10.0
} else if self.traps.contains(next_state) {
-10.0
} else {
-0.1
}
}

fn actions(&self, state: &Self::S) -> Vec<Self::A> {
if self.is_terminal(state) {
vec![]
} else {
vec![Move::Up, Move::Down, Move::Left, Move::Right]
}
}

fn discount_factor(&self) -> f64 {
self.gamma
}

fn is_terminal(&self, state: &Self::S) -> bool {
*state == self.goal || self.traps.contains(state)
}
}

/// GUI tool that visualizes a tabular Q-learning agent navigating a
/// grid world, with controls to single-step, batch-train, and reset.
pub struct GridWorldTool {
    // The environment the agent acts in.
    env: GridWorldEnv,
    // Tabular Q-learning agent keyed by (GridState, Move).
    agent: TabularQAgent<GridState, Move>,
    // The agent's position within the current episode.
    current_state: GridState,
    // Number of completed episodes (incremented by `reset_episode`).
    episodes: u32,
    // Reward accumulated during the current episode.
    total_reward: f64,
    // Steps taken during the current episode.
    steps: u32,
}

impl Default for GridWorldTool {
    /// Build the demo scenario: a 5x5 board with the start in the top-left
    /// corner, the goal in the bottom-right, and two traps mid-board.
    fn default() -> Self {
        let start = GridState { x: 0, y: 0 };
        let env = GridWorldEnv {
            width: 5,
            height: 5,
            start,
            goal: GridState { x: 4, y: 4 },
            traps: vec![GridState { x: 2, y: 2 }, GridState { x: 3, y: 2 }],
            gamma: 0.9,
        };

        Self {
            // Hyperparameters (0.1, 0.9, 0.1) — presumably learning rate,
            // discount, and exploration rate; confirm against
            // TabularQAgent::new. The "Reset Agent" button reuses the same
            // values.
            agent: TabularQAgent::new(0.1, 0.9, 0.1),
            current_state: start,
            env,
            episodes: 0,
            total_reward: 0.0,
            steps: 0,
        }
    }
}

impl GridWorldTool {
fn step_agent(&mut self) {
if self.env.is_terminal(&self.current_state) {
self.reset_episode();
return;
}

let actions = self.env.actions(&self.current_state);
if actions.is_empty() {
return;
}

if let Some(action) = self.agent.select_action(&self.current_state, &actions) {
let mut expected_next = self.current_state;
Comment on lines +118 to +143
match action {
Move::Up => expected_next.y -= 1,
Move::Down => expected_next.y += 1,
Move::Left => expected_next.x -= 1,
Move::Right => expected_next.x += 1,
}

let is_valid = expected_next.x >= 0
&& expected_next.x < self.env.width
&& expected_next.y >= 0
&& expected_next.y < self.env.height;

let next_state = if is_valid {
expected_next
} else {
self.current_state
};

Comment on lines +143 to +161
let reward = self.env.reward(&self.current_state, &action, &next_state);
let next_actions = self.env.actions(&next_state);

self.agent.update(
&self.current_state,
&action,
reward,
&next_state,
&next_actions,
);

self.current_state = next_state;
self.total_reward += reward;
self.steps += 1;
}
}

fn reset_episode(&mut self) {
self.current_state = self.env.start;
self.episodes += 1;
self.total_reward = 0.0;
self.steps = 0;
}
}

impl AiTool for GridWorldTool {
fn name(&self) -> &'static str {
"Grid World (RL)"
}

fn show(&mut self, ctx: &egui::Context) {
egui::Window::new("Grid World Navigation (Q-Learning)").show(ctx, |ui| {
ui.horizontal(|ui| {
if ui.button("Step").clicked() {
self.step_agent();
}
if ui.button("Train (100 Episodes)").clicked() {
for _ in 0..100 {
let mut temp_steps = 0;
while !self.env.is_terminal(&self.current_state) && temp_steps < 100 {
self.step_agent();
temp_steps += 1;
}
self.reset_episode();
}
}
Comment on lines +198 to +207
if ui.button("Reset Agent").clicked() {
self.agent = TabularQAgent::new(0.1, 0.9, 0.1);
self.reset_episode();
self.episodes = 0;
}
});

ui.horizontal(|ui| {
ui.label(format!("Episode: {}", self.episodes));
ui.label(format!("Steps: {}", self.steps));
ui.label(format!("Reward: {:.2}", self.total_reward));
});

let cell_size = 40.0;
let grid_size = egui::vec2(
self.env.width as f32 * cell_size,
self.env.height as f32 * cell_size,
);

let (response, painter) = ui.allocate_painter(grid_size, egui::Sense::hover());
let rect = response.rect;

for x in 0..self.env.width {
for y in 0..self.env.height {
let top_left =
rect.min + egui::vec2(x as f32 * cell_size, y as f32 * cell_size);
let cell_rect =
egui::Rect::from_min_size(top_left, egui::vec2(cell_size, cell_size));

let state = GridState { x, y };
let mut fill_color = egui::Color32::from_gray(200);

if state == self.env.goal {
fill_color = egui::Color32::GREEN;
} else if self.env.traps.contains(&state) {
fill_color = egui::Color32::RED;
} else if state == self.current_state {
fill_color = egui::Color32::BLUE;
} else if state == self.env.start {
fill_color = egui::Color32::LIGHT_BLUE;
}

painter.rect_filled(cell_rect, 0.0, fill_color);
painter.rect_stroke(
cell_rect,
0.0,
egui::Stroke::new(1.0, egui::Color32::BLACK),
egui::StrokeKind::Middle,
);
}
}
});
}
}
2 changes: 2 additions & 0 deletions math_explorer_gui/src/tabs/ai/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use eframe::egui;

pub mod activation_functions;
pub mod attention_maps;
pub mod grid_world;
pub mod loss_landscape;
pub mod training_monitor;

Expand All @@ -28,6 +29,7 @@ impl Default for AiTab {
Box::new(training_monitor::TrainingMonitorTool::default()),
Box::new(activation_functions::ActivationFunctionsTool::default()),
Box::new(attention_maps::AttentionMapsTool::default()),
Box::new(grid_world::GridWorldTool::default()),
],
selected_tool_index: 0,
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use super::BatteryDegradationTool;
use eframe::egui;
use math_explorer::applied::battery_degradation::{
Capacity, DepthOfDischarge, PowerLawModel,
};
use math_explorer::applied::battery_degradation::{Capacity, DepthOfDischarge, PowerLawModel};

pub struct LifetimeEstimatorTool {
target_capacity: f64,
Expand Down
2 changes: 1 addition & 1 deletion todo_gui.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ This document outlines the roadmap for integrating the various modules of the `m
### 3.2 Reinforcement Learning
* **Module:** `ai::reinforcement_learning`
* **Features:**
* [ ] **Grid World:** Agent navigation visualization.
* [x] **Grid World:** Agent navigation visualization.
* [ ] **Q-Table Inspector:** Heatmap of Q-values.
* [ ] **Reward Plots:** Cumulative reward over episodes.

Expand Down
Loading