From fde7ed9c855d4e3b469a71864872920cb9f1fd25 Mon Sep 17 00:00:00 2001 From: Dan Bravender Date: Mon, 19 Feb 2024 16:09:51 -0500 Subject: [PATCH] Save and load neural networks --- .gitignore | 4 ++++ Cargo.lock | 31 +++++++++++++++++++++++++++++++ Cargo.toml | 2 ++ src/agent.rs | 18 ++++++++++++++++++ src/editor.rs | 10 ++++++++++ src/main.rs | 8 ++++++++ src/nn.rs | 19 +++++++++++++++++-- src/simulation.rs | 18 ++++++++++++++++++ 8 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0f5b095 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.DS_Store +/.vscode +/target +output.brain diff --git a/Cargo.lock b/Cargo.lock index 5dd5271..2719647 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,6 +65,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -268,10 +277,12 @@ dependencies = [ name = "flappy-ai" version = "0.1.0" dependencies = [ + "bincode", "egui-macroquad", "macroquad", "once_cell", "rand", + "serde", ] [[package]] @@ -793,6 +804,26 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "serde" +version = "1.0.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.15", +] + [[package]] name = "simd-adler32" version = "0.3.5" diff --git a/Cargo.toml b/Cargo.toml index 5bc7b74..6bc7275 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,3 +10,5 @@ egui-macroquad = "0.15.0" macroquad = "0.3.25" once_cell = "1.17.1" rand = "0.8.5" +serde = { version = "1.0", features = ["derive"] } +bincode = "1.3.3" diff --git a/src/agent.rs b/src/agent.rs index d995fba..f27868b 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -30,6 +30,12 @@ impl Bird { } } + pub fn with_net(net: &Net) -> Self { + let mut new_bird = Bird::new(); + new_bird.brain = net.clone(); + new_bird + } + pub fn with_brain(other: &Bird) -> Self { let mut new_bird = Bird::new(); new_bird.brain = other.brain.clone(); @@ -107,6 +113,18 @@ impl Bird { self.brain.mutate(); } + pub fn save(&self) { + self.brain + .save_to_disk("output.brain") + .expect("Failed to save brain file"); + } + + pub fn load(&self) -> Net { + self.brain + .load_from_disk("output.brain") + .expect("Failed to load brain") + } + fn mark_dead(&mut self) { self.is_dead = true; } diff --git a/src/editor.rs b/src/editor.rs index 46983f6..d003ac5 100644 --- a/src/editor.rs +++ b/src/editor.rs @@ -10,6 +10,8 @@ pub struct Settings { pub is_frame_skip: bool, pub is_show_egui: bool, pub show_one_bird: bool, + pub save: bool, + pub load: bool, } pub struct Editor { @@ -25,6 +27,8 @@ impl Settings { is_frame_skip: false, is_show_egui: true, show_one_bird: false, + save: false, + load: false, } } } @@ -83,6 +87,12 @@ impl Editor { if ui.add(egui::Button::new("Restart")).clicked() { self.settings.is_restart = true; } + if ui.add(egui::Button::new("Save")).clicked() { + self.settings.save = true; + } + if ui.add(egui::Button::new("Load")).clicked() { + self.settings.load = true; + } }); }); }); diff --git a/src/main.rs b/src/main.rs index 4989f25..06ed1b8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,6 +36,14 @@ async fn main() { stats = simulation.update(&editor.settings).unwrap_or(stats); } } + if editor.settings.save { + editor.settings.save = false; + simulation.save(); + } + if editor.settings.load { + editor.settings.load = false; + simulation.load(); + } if is_key_pressed(KeyCode::Q) { break; diff --git a/src/nn.rs b/src/nn.rs index 98863e1..a970459 100644 --- a/src/nn.rs +++ b/src/nn.rs @@ -1,14 +1,17 @@ use macroquad::rand::gen_range; +use serde::{Deserialize, Serialize}; +use std::fs::File; +use std::io::{self, BufReader, BufWriter}; use crate::*; -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct Net { n_inputs: usize, layers: Vec, } -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] struct Layer { nodes: Vec>, } @@ -57,6 +60,18 @@ impl Net { pub fn mutate(&mut self) { self.layers.iter_mut().for_each(|l| l.mutate()); } + + pub fn save_to_disk(&self, filename: &str) -> io::Result<()> { + let file = File::create(filename)?; + let writer = BufWriter::new(file); + bincode::serialize_into(writer, self).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } + + pub fn load_from_disk(&self, filename: &str) -> io::Result { + let file = File::open(filename)?; + let reader = BufReader::new(file); + bincode::deserialize_from(reader).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } } impl Layer { diff --git a/src/simulation.rs b/src/simulation.rs index 98a1b04..e425d94 100644 --- a/src/simulation.rs +++ b/src/simulation.rs @@ -68,6 +68,24 @@ impl Simulation { self.birds.iter().for_each(|b| b.draw()); } + pub fn save(&self) { + let bird = self.birds.iter().filter(|b| !b.is_dead).next().unwrap(); + bird.save(); + } + + pub fn load(&mut self) -> () { + let bird = self.birds.first().unwrap(); + let net = bird.load(); + let mut new_birds = Vec::new(); + + for _ in 0..NUM_BIRDS { + let mut new_bird = Bird::with_net(&net); + new_bird.mutate(); + new_birds.push(new_bird); + } + self.birds = new_birds; + } + fn selection(&self) -> Vec { let mut rng = thread_rng(); let gene_pool = self.calc_fitness();