From 290a2419576a7fdaaa5368940de1ecbdf2350879 Mon Sep 17 00:00:00 2001 From: fderuiter <127706008+fderuiter@users.noreply.github.com> Date: Mon, 16 Mar 2026 12:52:53 +0000 Subject: [PATCH] feat: Add Grid World AI tool to math_explorer_gui Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- .../src/tabs/ai/attention_maps.rs | 31 ++- math_explorer_gui/src/tabs/ai/grid_world.rs | 261 ++++++++++++++++++ math_explorer_gui/src/tabs/ai/mod.rs | 2 + .../battery_degradation/lifetime_estimator.rs | 4 +- todo_gui.md | 2 +- 5 files changed, 283 insertions(+), 17 deletions(-) create mode 100644 math_explorer_gui/src/tabs/ai/grid_world.rs diff --git a/math_explorer_gui/src/tabs/ai/attention_maps.rs b/math_explorer_gui/src/tabs/ai/attention_maps.rs index e64210c1..9a5be3ef 100644 --- a/math_explorer_gui/src/tabs/ai/attention_maps.rs +++ b/math_explorer_gui/src/tabs/ai/attention_maps.rs @@ -29,12 +29,13 @@ impl Default for AttentionMapsTool { impl AttentionMapsTool { fn recalculate(&mut self) { - let (output, weights) = math_explorer::ai::transformer::attention::scaled_dot_product_attention( - &self.q_matrix, - &self.k_matrix, - &self.v_matrix, - None, - ); + let (output, weights) = + math_explorer::ai::transformer::attention::scaled_dot_product_attention( + &self.q_matrix, + &self.k_matrix, + &self.v_matrix, + None, + ); self.output_matrix = output; self.attention_weights = weights; } @@ -93,10 +94,8 @@ impl AttentionMapsTool { ((1.0 - normalized) * 200.0 + 55.0) as u8, ); - let (rect, _response) = ui.allocate_exact_size( - egui::vec2(40.0, 40.0), - egui::Sense::hover(), - ); + let (rect, _response) = + ui.allocate_exact_size(egui::vec2(40.0, 40.0), egui::Sense::hover()); ui.painter().rect_filled(rect, 2.0, color); @@ -135,13 +134,15 @@ impl AiTool for AttentionMapsTool { ui.vertical(|ui| { let mut changed = false; ui.group(|ui| { - changed |= Self::draw_matrix_input(ui, "Queries (Q)", &mut self.q_matrix); + changed |= + Self::draw_matrix_input(ui, "Queries (Q)", &mut self.q_matrix); }); ui.group(|ui| { changed |= Self::draw_matrix_input(ui, "Keys (K)", &mut self.k_matrix); }); ui.group(|ui| { - changed |= Self::draw_matrix_input(ui, "Values (V)", &mut self.v_matrix); + changed |= + Self::draw_matrix_input(ui, "Values (V)", &mut self.v_matrix); }); if changed { @@ -153,7 +154,11 @@ impl AiTool for AttentionMapsTool { ui.vertical(|ui| { ui.group(|ui| { - Self::draw_heatmap(ui, "Attention Weights (softmax(Q * K^T / sqrt(d_k)))", &self.attention_weights); + Self::draw_heatmap( + ui, + "Attention Weights (softmax(Q * K^T / sqrt(d_k)))", + &self.attention_weights, + ); }); ui.add_space(20.0); diff --git a/math_explorer_gui/src/tabs/ai/grid_world.rs b/math_explorer_gui/src/tabs/ai/grid_world.rs new file mode 100644 index 00000000..7648fa83 --- /dev/null +++ b/math_explorer_gui/src/tabs/ai/grid_world.rs @@ -0,0 +1,261 @@ +use crate::tabs::ai::AiTool; +use eframe::egui; +use math_explorer::ai::reinforcement_learning::{ + algorithms::TabularQAgent, Action, MarkovDecisionProcess, State, +}; +use std::hash::Hash; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct GridState { + pub x: i32, + pub y: i32, +} + +impl State for GridState {} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Move { + Up, + Down, + Left, + Right, +} + +impl Action for Move {} + +pub struct GridWorldEnv { + pub width: i32, + pub height: i32, + pub goal: GridState, + pub start: GridState, + pub traps: Vec, + pub gamma: f64, +} + +impl MarkovDecisionProcess for GridWorldEnv { + type S = GridState; + type A = Move; + + fn transition_probability( + &self, + next_state: &Self::S, + current_state: &Self::S, + action: &Self::A, + ) -> f64 { + // Simplified deterministic transition for the tool + let mut expected_next = *current_state; + match action { + Move::Up => expected_next.y -= 1, + Move::Down => expected_next.y += 1, + Move::Left => expected_next.x -= 1, + Move::Right => expected_next.x += 1, + } + + let is_valid = expected_next.x >= 0 + && expected_next.x < self.width + && expected_next.y >= 0 + && expected_next.y < self.height; + + let actual_next = if is_valid { + expected_next + } else { + *current_state + }; + + if *next_state == actual_next { + 1.0 + } else { + 0.0 + } + } + + fn reward(&self, _current_state: &Self::S, _action: &Self::A, next_state: &Self::S) -> f64 { + if *next_state == self.goal { + 10.0 + } else if self.traps.contains(next_state) { + -10.0 + } else { + -0.1 + } + } + + fn actions(&self, state: &Self::S) -> Vec { + if self.is_terminal(state) { + vec![] + } else { + vec![Move::Up, Move::Down, Move::Left, Move::Right] + } + } + + fn discount_factor(&self) -> f64 { + self.gamma + } + + fn is_terminal(&self, state: &Self::S) -> bool { + *state == self.goal || self.traps.contains(state) + } +} + +pub struct GridWorldTool { + env: GridWorldEnv, + agent: TabularQAgent, + current_state: GridState, + episodes: u32, + total_reward: f64, + steps: u32, +} + +impl Default for GridWorldTool { + fn default() -> Self { + let env = GridWorldEnv { + width: 5, + height: 5, + start: GridState { x: 0, y: 0 }, + goal: GridState { x: 4, y: 4 }, + traps: vec![GridState { x: 2, y: 2 }, GridState { x: 3, y: 2 }], + gamma: 0.9, + }; + let agent = TabularQAgent::new(0.1, 0.9, 0.1); + Self { + current_state: env.start, + env, + agent, + episodes: 0, + total_reward: 0.0, + steps: 0, + } + } +} + +impl GridWorldTool { + fn step_agent(&mut self) { + if self.env.is_terminal(&self.current_state) { + self.reset_episode(); + return; + } + + let actions = self.env.actions(&self.current_state); + if actions.is_empty() { + return; + } + + if let Some(action) = self.agent.select_action(&self.current_state, &actions) { + let mut expected_next = self.current_state; + match action { + Move::Up => expected_next.y -= 1, + Move::Down => expected_next.y += 1, + Move::Left => expected_next.x -= 1, + Move::Right => expected_next.x += 1, + } + + let is_valid = expected_next.x >= 0 + && expected_next.x < self.env.width + && expected_next.y >= 0 + && expected_next.y < self.env.height; + + let next_state = if is_valid { + expected_next + } else { + self.current_state + }; + + let reward = self.env.reward(&self.current_state, &action, &next_state); + let next_actions = self.env.actions(&next_state); + + self.agent.update( + &self.current_state, + &action, + reward, + &next_state, + &next_actions, + ); + + self.current_state = next_state; + self.total_reward += reward; + self.steps += 1; + } + } + + fn reset_episode(&mut self) { + self.current_state = self.env.start; + self.episodes += 1; + self.total_reward = 0.0; + self.steps = 0; + } +} + +impl AiTool for GridWorldTool { + fn name(&self) -> &'static str { + "Grid World (RL)" + } + + fn show(&mut self, ctx: &egui::Context) { + egui::Window::new("Grid World Navigation (Q-Learning)").show(ctx, |ui| { + ui.horizontal(|ui| { + if ui.button("Step").clicked() { + self.step_agent(); + } + if ui.button("Train (100 Episodes)").clicked() { + for _ in 0..100 { + let mut temp_steps = 0; + while !self.env.is_terminal(&self.current_state) && temp_steps < 100 { + self.step_agent(); + temp_steps += 1; + } + self.reset_episode(); + } + } + if ui.button("Reset Agent").clicked() { + self.agent = TabularQAgent::new(0.1, 0.9, 0.1); + self.reset_episode(); + self.episodes = 0; + } + }); + + ui.horizontal(|ui| { + ui.label(format!("Episode: {}", self.episodes)); + ui.label(format!("Steps: {}", self.steps)); + ui.label(format!("Reward: {:.2}", self.total_reward)); + }); + + let cell_size = 40.0; + let grid_size = egui::vec2( + self.env.width as f32 * cell_size, + self.env.height as f32 * cell_size, + ); + + let (response, painter) = ui.allocate_painter(grid_size, egui::Sense::hover()); + let rect = response.rect; + + for x in 0..self.env.width { + for y in 0..self.env.height { + let top_left = + rect.min + egui::vec2(x as f32 * cell_size, y as f32 * cell_size); + let cell_rect = + egui::Rect::from_min_size(top_left, egui::vec2(cell_size, cell_size)); + + let state = GridState { x, y }; + let mut fill_color = egui::Color32::from_gray(200); + + if state == self.env.goal { + fill_color = egui::Color32::GREEN; + } else if self.env.traps.contains(&state) { + fill_color = egui::Color32::RED; + } else if state == self.current_state { + fill_color = egui::Color32::BLUE; + } else if state == self.env.start { + fill_color = egui::Color32::LIGHT_BLUE; + } + + painter.rect_filled(cell_rect, 0.0, fill_color); + painter.rect_stroke( + cell_rect, + 0.0, + egui::Stroke::new(1.0, egui::Color32::BLACK), + egui::StrokeKind::Middle, + ); + } + } + }); + } +} diff --git a/math_explorer_gui/src/tabs/ai/mod.rs b/math_explorer_gui/src/tabs/ai/mod.rs index da8b1c53..f8dee368 100644 --- a/math_explorer_gui/src/tabs/ai/mod.rs +++ b/math_explorer_gui/src/tabs/ai/mod.rs @@ -3,6 +3,7 @@ use eframe::egui; pub mod activation_functions; pub mod attention_maps; +pub mod grid_world; pub mod loss_landscape; pub mod training_monitor; @@ -28,6 +29,7 @@ impl Default for AiTab { Box::new(training_monitor::TrainingMonitorTool::default()), Box::new(activation_functions::ActivationFunctionsTool::default()), Box::new(attention_maps::AttentionMapsTool::default()), + Box::new(grid_world::GridWorldTool::default()), ], selected_tool_index: 0, } diff --git a/math_explorer_gui/src/tabs/battery_degradation/lifetime_estimator.rs b/math_explorer_gui/src/tabs/battery_degradation/lifetime_estimator.rs index 0583a6da..33f134ca 100644 --- a/math_explorer_gui/src/tabs/battery_degradation/lifetime_estimator.rs +++ b/math_explorer_gui/src/tabs/battery_degradation/lifetime_estimator.rs @@ -1,8 +1,6 @@ use super::BatteryDegradationTool; use eframe::egui; -use math_explorer::applied::battery_degradation::{ - Capacity, DepthOfDischarge, PowerLawModel, -}; +use math_explorer::applied::battery_degradation::{Capacity, DepthOfDischarge, PowerLawModel}; pub struct LifetimeEstimatorTool { target_capacity: f64, diff --git a/todo_gui.md b/todo_gui.md index 122c3e52..5ffaec53 100644 --- a/todo_gui.md +++ b/todo_gui.md @@ -89,7 +89,7 @@ This document outlines the roadmap for integrating the various modules of the `m ### 3.2 Reinforcement Learning * **Module:** `ai::reinforcement_learning` * **Features:** - * [ ] **Grid World:** Agent navigation visualization. + * [x] **Grid World:** Agent navigation visualization. * [ ] **Q-Table Inspector:** Heatmap of Q-values. * [ ] **Reward Plots:** Cumulative reward over episodes.