From 758453066076b96b76d8ee0289083b0535740ce7 Mon Sep 17 00:00:00 2001
From: drehwald
Date: Fri, 11 Dec 2020 04:45:44 +0100
Subject: [PATCH 1/4] starting transfer to spaces

---
 Cargo.toml | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/Cargo.toml b/Cargo.toml
index 57cbf39..8efa4c0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,6 +20,10 @@ ndarray-stats = "0.3.0"
 default = []
 download = ["datasets/download"]
 
+[dependencies.spaces]
+git = "https://github.com/tspooner/spaces"
+branch = "master"
+
 [dependencies.datasets]
 git = "https://github.com/ZuseZ4/Datasets"
 branch = "master"

From f9fe68a48b220adb4afc0bfe9eeae0b488641690 Mon Sep 17 00:00:00 2001
From: drehwald
Date: Sat, 12 Dec 2020 03:42:10 +0100
Subject: [PATCH 2/4] add spaces to Environment

---
 src/rl/env/env_trait.rs    |  7 +++++++
 src/rl/env/fortress.rs     |  4 ++++
 src/rl/env/tictactoe.rs    |  4 ++++
 src/rl/training/trainer.rs | 16 ++++++++++++----
 4 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/src/rl/env/env_trait.rs b/src/rl/env/env_trait.rs
index 34cbe2a..b61f89c 100644
--- a/src/rl/env/env_trait.rs
+++ b/src/rl/env/env_trait.rs
@@ -1,7 +1,14 @@
 use ndarray::{Array1, Array2};
+use spaces::Space;
 
 /// This trait defines all functions on which agents and other user might depend.
 pub trait Environment {
+    /// State representation
+    type StateSpace: Space;
+
+    /// Action space representation
+    type ActionSpace: Space;
+
     /// The central function which causes the environment to pass various information to the agent.
     ///
     /// The Array2 encodes the environment (the board).
diff --git a/src/rl/env/fortress.rs b/src/rl/env/fortress.rs
index ff49f87..abb7017 100644
--- a/src/rl/env/fortress.rs
+++ b/src/rl/env/fortress.rs
@@ -1,5 +1,7 @@
 use crate::rl::env::env_trait::Environment;
 use ndarray::{Array, Array1, Array2};
+use spaces::discrete::Ordinal;
+use spaces::*;
 use std::cmp::Ordering;
 
 static NEIGHBOURS_LIST: [&[usize]; 6 * 6] = [
@@ -54,6 +56,8 @@ pub struct Fortress {
 }
 
 impl Environment for Fortress {
+    type StateSpace = ProductSpace<Ordinal>;
+    type ActionSpace = Ordinal;
     fn step(&self) -> (Array2<f32>, Array1<bool>, f32, bool) {
         if !self.active {
             eprintln!("Warning, calling step() after done = true!");
diff --git a/src/rl/env/tictactoe.rs b/src/rl/env/tictactoe.rs
index 1f8c67f..32ae530 100644
--- a/src/rl/env/tictactoe.rs
+++ b/src/rl/env/tictactoe.rs
@@ -1,5 +1,7 @@
 use crate::rl::env::env_trait::Environment;
 use ndarray::{Array, Array1, Array2};
+use spaces::discrete::Ordinal;
+use spaces::*;
 
 static BITMASKS: [&[u16]; 9] = [
     //TODO FIX: BROKEN
@@ -39,6 +41,8 @@ impl Default for TicTacToe {
 }
 
 impl Environment for TicTacToe {
+    type StateSpace = ProductSpace<Ordinal>;
+    type ActionSpace = Ordinal;
     fn step(&self) -> (Array2<f32>, Array1<bool>, f32, bool) {
         // storing current position into ndarray
         let position = board_as_arr(self.player1, self.player2)
diff --git a/src/rl/training/trainer.rs b/src/rl/training/trainer.rs
index a3bae1a..0dff128 100644
--- a/src/rl/training/trainer.rs
+++ b/src/rl/training/trainer.rs
@@ -1,16 +1,24 @@
 use crate::rl::agent::Agent;
 use crate::rl::env::Environment;
 use ndarray::Array2;
+use spaces::Space;
 
 /// A trainer works on a given environment and a set of agents.
-pub struct Trainer {
-    env: Box<dyn Environment>,
+pub struct Trainer<S, A>
+where
+    S: Space,
+    A: Space,
+{
+    env: Box<dyn Environment<StateSpace = S, ActionSpace = A>>,
     res: Vec<(u32, u32, u32)>,
     agents: Vec<Box<dyn Agent>>,
 }
 
-impl Trainer {
+impl<S: Space, A: Space> Trainer<S, A> {
     /// We construct a Trainer by passing a single environment and one or more (possibly different) agents.
-    pub fn new(env: Box<dyn Environment>, agents: Vec<Box<dyn Agent>>) -> Result<Trainer, String> {
+    pub fn new(
+        env: Box<dyn Environment<StateSpace = S, ActionSpace = A>>,
+        agents: Vec<Box<dyn Agent>>,
+    ) -> Result<Self, String> {
         if agents.is_empty() {
             return Err("At least one agent required!".to_string());

From e01dd8c8fa7db11b2cb6919d6d48bec9c4b39713 Mon Sep 17 00:00:00 2001
From: drehwald
Date: Sat, 12 Dec 2020 04:01:55 +0100
Subject: [PATCH 3/4] add spaces to Agents

---
 examples/fortress.rs         | 31 ++++++++++++++-------------
 examples/tictactoe.rs        | 31 ++++++++++++++-------------
 src/rl/agent/agent_trait.rs  |  7 ++++++-
 src/rl/agent/dql_agent.rs    |  3 ++-
 src/rl/agent/human_player.rs |  3 ++-
 src/rl/agent/ql_agent.rs     |  3 ++-
 src/rl/agent/random_agent.rs |  3 ++-
 src/rl/training/trainer.rs   |  4 ++--
 8 files changed, 52 insertions(+), 33 deletions(-)

diff --git a/examples/fortress.rs b/examples/fortress.rs
index 0bd57b1..4f3c844 100644
--- a/examples/fortress.rs
+++ b/examples/fortress.rs
@@ -2,6 +2,8 @@ use agent::*;
 use env::Fortress;
 use rust_rl::network::nn::NeuralNetwork;
 use rust_rl::rl::{agent, env, training};
+use spaces::discrete::Ordinal;
+use spaces::*;
 use std::io;
 use training::{utils, Trainer};
 
@@ -62,21 +64,24 @@ pub fn main() {
     );
 }
 
-fn get_agents(agent_nums: Vec<usize>) -> Result<Vec<Box<dyn Agent>>, String> {
-    let mut res: Vec<Box<dyn Agent>> = vec![];
+fn get_agents(
+    agent_nums: Vec<usize>,
+) -> Result<Vec<Box<dyn Agent<ProductSpace<Ordinal>, Ordinal>>>, String> {
+    let mut res: Vec<Box<dyn Agent<ProductSpace<Ordinal>, Ordinal>>> = vec![];
     let batch_size = 16;
     for agent_num in agent_nums {
-        let new_agent: Result<Box<dyn Agent>, String> = match agent_num {
-            1 => Ok(Box::new(DQLAgent::new(
-                1.,
-                batch_size,
-                new(0.001, batch_size),
-            ))),
-            2 => Ok(Box::new(QLAgent::new(1., 6 * 6))),
-            3 => Ok(Box::new(RandomAgent::new())),
-            4 => Ok(Box::new(HumanPlayer::new())),
-            _ => Err("Only implemented agents 1-4!".to_string()),
-        };
+        let new_agent: Result<Box<dyn Agent<ProductSpace<Ordinal>, Ordinal>>, String> =
+            match agent_num {
+                1 => Ok(Box::new(DQLAgent::new(
+                    1.,
+                    batch_size,
+                    new(0.001, batch_size),
+                ))),
+                2 => Ok(Box::new(QLAgent::new(1., 6 * 6))),
+                3 => Ok(Box::new(RandomAgent::new())),
+                4 => Ok(Box::new(HumanPlayer::new())),
+                _ => Err("Only implemented agents 1-4!".to_string()),
+            };
         res.push(new_agent?);
     }
     Ok(res)
diff --git a/examples/tictactoe.rs b/examples/tictactoe.rs
index d0025cf..68f5d5f 100644
--- a/examples/tictactoe.rs
+++ b/examples/tictactoe.rs
@@ -2,6 +2,8 @@ use agent::*;
 use env::TicTacToe;
 use rust_rl::network::nn::NeuralNetwork;
 use rust_rl::rl::{agent, env, training};
+use spaces::discrete::Ordinal;
+use spaces::*;
 use std::io;
 use training::{utils, Trainer};
 
@@ -56,21 +58,24 @@ pub fn main() {
     );
 }
 
-fn get_agents(agent_nums: Vec<usize>) -> Result<Vec<Box<dyn Agent>>, String> {
-    let mut res: Vec<Box<dyn Agent>> = vec![];
+fn get_agents(
+    agent_nums: Vec<usize>,
+) -> Result<Vec<Box<dyn Agent<ProductSpace<Ordinal>, Ordinal>>>, String> {
+    let mut res: Vec<Box<dyn Agent<ProductSpace<Ordinal>, Ordinal>>> = vec![];
     let batch_size = 16;
     for agent_num in agent_nums {
-        let new_agent: Result<Box<dyn Agent>, String> = match agent_num {
-            1 => Ok(Box::new(DQLAgent::new(
-                1.,
-                batch_size,
-                new(0.001, batch_size),
-            ))),
-            2 => Ok(Box::new(QLAgent::new(1., 3 * 3))),
-            3 => Ok(Box::new(RandomAgent::new())),
-            4 => Ok(Box::new(HumanPlayer::new())),
-            _ => Err("Only implemented agents 1-4!".to_string()),
-        };
+        let new_agent: Result<Box<dyn Agent<ProductSpace<Ordinal>, Ordinal>>, String> =
+            match agent_num {
+                1 => Ok(Box::new(DQLAgent::new(
+                    1.,
+                    batch_size,
+                    new(0.001, batch_size),
+                ))),
+                2 => Ok(Box::new(QLAgent::new(1., 3 * 3))),
+                3 => Ok(Box::new(RandomAgent::new())),
+                4 => Ok(Box::new(HumanPlayer::new())),
+                _ => Err("Only implemented agents 1-4!".to_string()),
+            };
         res.push(new_agent?);
     }
     Ok(res)
diff --git a/src/rl/agent/agent_trait.rs b/src/rl/agent/agent_trait.rs
index d91ea4c..ae981fa 100644
--- a/src/rl/agent/agent_trait.rs
+++ b/src/rl/agent/agent_trait.rs
@@ -1,7 +1,12 @@
 use ndarray::{Array1, Array2};
+use spaces::Space;
 
 /// A trait including all functions required to train them.
-pub trait Agent {
+pub trait Agent<S, A>
+where
+    S: Space,
+    A: Space,
+{
     /// Returns a simple string identifying the specific agent type.
     fn get_id(&self) -> String;
 
diff --git a/src/rl/agent/dql_agent.rs b/src/rl/agent/dql_agent.rs
index ecb86c9..0f8871e 100644
--- a/src/rl/agent/dql_agent.rs
+++ b/src/rl/agent/dql_agent.rs
@@ -2,6 +2,7 @@ use crate::network::nn::NeuralNetwork;
 use crate::rl::agent::Agent;
 use crate::rl::algorithms::DQlearning;
 use ndarray::{Array1, Array2};
+use spaces::Space;
 
 /// An agent using Deep-Q-Learning, based on a small neural network.
 pub struct DQLAgent {
@@ -19,7 +20,7 @@ impl DQLAgent {
     }
 }
 
-impl Agent for DQLAgent {
+impl<S: Space, A: Space> Agent<S, A> for DQLAgent {
     fn get_id(&self) -> String {
         "dqlearning agent".to_string()
     }
diff --git a/src/rl/agent/human_player.rs b/src/rl/agent/human_player.rs
index 14d99b0..06cfbf1 100644
--- a/src/rl/agent/human_player.rs
+++ b/src/rl/agent/human_player.rs
@@ -1,5 +1,6 @@
 use crate::rl::agent::agent_trait::Agent;
 use ndarray::{Array1, Array2};
+use spaces::Space;
 use std::io;
 
 /// An agent which just shows the user the current environment and lets the user decide about each action.
@@ -13,7 +14,7 @@ impl HumanPlayer {
     }
 }
 
-impl Agent for HumanPlayer {
+impl<S: Space, A: Space> Agent<S, A> for HumanPlayer {
     fn get_id(&self) -> String {
         "human player".to_string()
     }
diff --git a/src/rl/agent/ql_agent.rs b/src/rl/agent/ql_agent.rs
index 1fc41ad..928d300 100644
--- a/src/rl/agent/ql_agent.rs
+++ b/src/rl/agent/ql_agent.rs
@@ -2,6 +2,7 @@ use crate::rl::algorithms::Qlearning;
 use ndarray::{Array1, Array2};
 
 use crate::rl::agent::Agent;
+use spaces::Space;
 
 /// An agent working on a classical q-table.
 pub struct QLAgent {
@@ -19,7 +20,7 @@ impl QLAgent {
     }
 }
 
-impl Agent for QLAgent {
+impl<S: Space, A: Space> Agent<S, A> for QLAgent {
     fn get_id(&self) -> String {
         "qlearning agent".to_string()
     }
diff --git a/src/rl/agent/random_agent.rs b/src/rl/agent/random_agent.rs
index 7f73286..0529412 100644
--- a/src/rl/agent/random_agent.rs
+++ b/src/rl/agent/random_agent.rs
@@ -1,6 +1,7 @@
 use crate::rl::agent::agent_trait::Agent;
 use crate::rl::algorithms::utils;
 use ndarray::{Array1, Array2};
+use spaces::Space;
 /// An agent who acts randomly.
 ///
 /// All input is ignored except of the vector of possible actions.
@@ -16,7 +17,7 @@ impl RandomAgent {
     }
 }
 
-impl Agent for RandomAgent {
+impl<S: Space, A: Space> Agent<S, A> for RandomAgent {
     fn get_id(&self) -> String {
         "random agent".to_string()
     }
diff --git a/src/rl/training/trainer.rs b/src/rl/training/trainer.rs
index 0dff128..561ade2 100644
--- a/src/rl/training/trainer.rs
+++ b/src/rl/training/trainer.rs
@@ -10,14 +10,14 @@ where
 {
     env: Box<dyn Environment<StateSpace = S, ActionSpace = A>>,
     res: Vec<(u32, u32, u32)>,
-    agents: Vec<Box<dyn Agent>>,
+    agents: Vec<Box<dyn Agent<S, A>>>,
 }
 
 impl<S: Space, A: Space> Trainer<S, A> {
     /// We construct a Trainer by passing a single environment and one or more (possibly different) agents.
     pub fn new(
         env: Box<dyn Environment<StateSpace = S, ActionSpace = A>>,
-        agents: Vec<Box<dyn Agent>>,
+        agents: Vec<Box<dyn Agent<S, A>>>,
     ) -> Result<Self, String> {
         if agents.is_empty() {
             return Err("At least one agent required!".to_string());

From e1823071bd21b693df3b5995f6d66f16ef59432f Mon Sep 17 00:00:00 2001
From: drehwald
Date: Sat, 12 Dec 2020 23:39:16 +0100
Subject: [PATCH 4/4] WIP spaces

---
 Cargo.toml                   |  5 +----
 src/rl/agent/agent_trait.rs  |  5 ++++-
 src/rl/agent/dql_agent.rs    | 10 ++++++++--
 src/rl/agent/human_player.rs |  5 ++++-
 src/rl/agent/ql_agent.rs     |  8 ++++++--
 src/rl/agent/random_agent.rs |  8 ++++++--
 src/rl/env/env_trait.rs      |  8 ++++++--
 src/rl/env/fortress.rs       |  7 ++++---
 src/rl/env/tictactoe.rs      |  7 ++++---
 9 files changed, 43 insertions(+), 20 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 8efa4c0..6c9775c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,15 +15,12 @@ rand = "0.7"
 fnv = "1.0.7"
 ndarray-rand = "0.11.0"
 ndarray-stats = "0.3.0"
+spaces = "5.0.0"
 
 [features]
 default = []
 download = ["datasets/download"]
 
-[dependencies.spaces]
-git = "https://github.com/tspooner/spaces"
-branch = "master"
-
 [dependencies.datasets]
 git = "https://github.com/ZuseZ4/Datasets"
 branch = "master"
diff --git a/src/rl/agent/agent_trait.rs b/src/rl/agent/agent_trait.rs
index ae981fa..7ba0efe 100644
--- a/src/rl/agent/agent_trait.rs
+++ b/src/rl/agent/agent_trait.rs
@@ -1,6 +1,9 @@
 use ndarray::{Array1, Array2};
 use spaces::Space;
 
+pub type State<S: Space> = <S as Space>::Value;
+pub type Action<A: Space> = <A as Space>::Value;
+
 /// A trait including all functions required to train them.
 pub trait Agent<S, A>
 where
@@ -14,7 +17,7 @@ where
     ///
     /// The concrete encoding of actions as usize value has to be looked up in the documentation of the specific environment.
     /// Advanced agents shouldn't need knowledge about the used encoding.
-    fn get_move(&mut self, env: Array2<f32>, actions: Array1<bool>, reward: f32) -> usize;
+    fn get_move(&mut self, env: Array2<f32>, actions: Array1<bool>, reward: f32) -> &Action<A>;
 
     /// Informs the agent that the current epoch has finished and tells him about his final result.
     ///
diff --git a/src/rl/agent/dql_agent.rs b/src/rl/agent/dql_agent.rs
index 0f8871e..80d76ac 100644
--- a/src/rl/agent/dql_agent.rs
+++ b/src/rl/agent/dql_agent.rs
@@ -3,6 +3,11 @@ use crate::rl::agent::Agent;
 use crate::rl::algorithms::DQlearning;
 use ndarray::{Array1, Array2};
 use spaces::Space;
+use spaces::discrete::NonNegativeIntegers;
+
+
+pub type State<S: Space> = <S as Space>::Value;
+pub type Action<A: Space> = <A as Space>::Value;
 
 /// An agent using Deep-Q-Learning, based on a small neural network.
 pub struct DQLAgent {
@@ -30,8 +35,9 @@ impl<S: Space, A: Space> Agent<S, A> for DQLAgent {
         self.dqlearning.finish_round(result, final_state);
     }
 
-    fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, reward: f32) -> usize {
-        self.dqlearning.get_move(board, actions, reward)
+    fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, reward: f32) -> &Action<A> {
+        let res = self.dqlearning.get_move(board, actions, reward);
+        res
     }
 
     fn get_learning_rate(&self) -> f32 {
diff --git a/src/rl/agent/human_player.rs b/src/rl/agent/human_player.rs
index 06cfbf1..2571222 100644
--- a/src/rl/agent/human_player.rs
+++ b/src/rl/agent/human_player.rs
@@ -3,6 +3,9 @@ use ndarray::{Array1, Array2};
 use spaces::Space;
 use std::io;
 
+pub type State<S: Space> = <S as Space>::Value;
+pub type Action<A: Space> = <A as Space>::Value;
+
 /// An agent which just shows the user the current environment and lets the user decide about each action.
 #[derive(Default)]
 pub struct HumanPlayer {}
@@ -19,7 +22,7 @@ impl<S: Space, A: Space> Agent<S, A> for HumanPlayer {
         "human player".to_string()
     }
 
-    fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, _: f32) -> usize {
+    fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, _: f32) -> &Action<A> {
         let (n, m) = (board.shape()[0], board.shape()[1]);
         for i in 0..n {
             for j in 0..m {
diff --git a/src/rl/agent/ql_agent.rs b/src/rl/agent/ql_agent.rs
index 928d300..5c441a1 100644
--- a/src/rl/agent/ql_agent.rs
+++ b/src/rl/agent/ql_agent.rs
@@ -4,6 +4,9 @@ use ndarray::{Array1, Array2};
 use crate::rl::agent::Agent;
 use spaces::Space;
 
+pub type State<S: Space> = <S as Space>::Value;
+pub type Action<A: Space> = <A as Space>::Value;
+
 /// An agent working on a classical q-table.
 pub struct QLAgent {
     qlearning: Qlearning,
@@ -30,8 +33,9 @@ impl<S: Space, A: Space> Agent<S, A> for QLAgent {
         self.qlearning.finish_round(result, final_state);
     }
 
-    fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, reward: f32) -> usize {
-        self.qlearning.get_move(board, actions, reward)
+    fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, reward: f32) -> &Action<A> {
+        let res = self.qlearning.get_move(board, actions, reward);
+        res
     }
 
     fn set_learning_rate(&mut self, lr: f32) -> Result<(), String> {
diff --git a/src/rl/agent/random_agent.rs b/src/rl/agent/random_agent.rs
index 0529412..f29f9bd 100644
--- a/src/rl/agent/random_agent.rs
+++ b/src/rl/agent/random_agent.rs
@@ -3,6 +3,9 @@ use crate::rl::algorithms::utils;
 use ndarray::{Array1, Array2};
 use spaces::Space;
 
+pub type State<S: Space> = <S as Space>::Value;
+pub type Action<A: Space> = <A as Space>::Value;
+
 /// An agent who acts randomly.
 ///
 /// All input is ignored except of the vector of possible actions.
@@ -22,8 +25,9 @@ impl<S: Space, A: Space> Agent<S, A> for RandomAgent {
         "random agent".to_string()
     }
 
-    fn get_move(&mut self, _: Array2<f32>, actions: Array1<bool>, _: f32) -> usize {
-        utils::get_random_true_entry(actions)
+    fn get_move(&mut self, _: Array2<f32>, actions: Array1<bool>, _: f32) -> &Action<A> {
+        let res = utils::get_random_true_entry(actions);
+        res
     }
 
     fn finish_round(&mut self, _single_res: i32, _final_state: Array2<f32>) {}
diff --git a/src/rl/env/env_trait.rs b/src/rl/env/env_trait.rs
index b61f89c..1cb3953 100644
--- a/src/rl/env/env_trait.rs
+++ b/src/rl/env/env_trait.rs
@@ -1,6 +1,10 @@
-use ndarray::{Array1, Array2};
+use ndarray::Array2;
+use ndarray::Array1;
 use spaces::Space;
 
+pub type State<S: Space> = <S as Space>::Value;
+pub type Action<A: Space> = <A as Space>::Value;
+
 /// This trait defines all functions on which agents and other user might depend.
 pub trait Environment {
     /// State representation
@@ -20,7 +24,7 @@ pub trait Environment {
     ///
     /// If the action is allowed for the currently active agent then update the environment and return true.
     /// Otherwise do nothing and return false. The same agent can then try a new move.
-    fn take_action(&mut self, action: usize) -> bool;
+    fn take_action(&mut self, action: &Action<Self::ActionSpace>) -> bool;
     /// Shows the current envrionment state in a graphical way.
     ///
     /// The representation is environment specific and might be either by terminal, or in an extra window.
diff --git a/src/rl/env/fortress.rs b/src/rl/env/fortress.rs
index abb7017..7e1b3b1 100644
--- a/src/rl/env/fortress.rs
+++ b/src/rl/env/fortress.rs
@@ -1,6 +1,6 @@
 use crate::rl::env::env_trait::Environment;
 use ndarray::{Array, Array1, Array2};
-use spaces::discrete::Ordinal;
+use spaces::discrete::NonNegativeIntegers;
 use spaces::*;
 use std::cmp::Ordering;
 
@@ -57,7 +57,7 @@ pub struct Fortress {
 
 impl Environment for Fortress {
     type StateSpace = ProductSpace<Ordinal>;
-    type ActionSpace = Ordinal;
+    type ActionSpace = NonNegativeIntegers;
     fn step(&self) -> (Array2<f32>, Array1<bool>, f32, bool) {
         if !self.active {
             eprintln!("Warning, calling step() after done = true!");
@@ -101,7 +101,8 @@ impl Environment for Fortress {
         println!();
     }
 
-    fn take_action(&mut self, pos: usize) -> bool {
+    fn take_action(&mut self, pos: &u64) -> bool {
+        let pos = *pos as usize;
         let player_val = if self.first_player_turn { 1 } else { -1 };
 
         // check that field is not controlled by enemy, no enemy building on field, no own building on max lv (3) already exists
diff --git a/src/rl/env/tictactoe.rs b/src/rl/env/tictactoe.rs
index 32ae530..e64dce3 100644
--- a/src/rl/env/tictactoe.rs
+++ b/src/rl/env/tictactoe.rs
@@ -1,6 +1,6 @@
 use crate::rl::env::env_trait::Environment;
 use ndarray::{Array, Array1, Array2};
-use spaces::discrete::Ordinal;
+use spaces::discrete::NonNegativeIntegers;
 use spaces::*;
 
 static BITMASKS: [&[u16]; 9] = [
@@ -42,7 +42,7 @@ impl Default for TicTacToe {
 
 impl Environment for TicTacToe {
     type StateSpace = ProductSpace<Ordinal>;
-    type ActionSpace = Ordinal;
+    type ActionSpace = NonNegativeIntegers;
     fn step(&self) -> (Array2<f32>, Array1<bool>, f32, bool) {
         // storing current position into ndarray
         let position = board_as_arr(self.player1, self.player2)
@@ -82,7 +82,8 @@ impl Environment for TicTacToe {
         }
     }
 
-    fn take_action(&mut self, pos: usize) -> bool {
+    fn take_action(&mut self, pos: &u64) -> bool {
+        let pos = *pos as usize;
        if pos > 8 {
            return false;
        }
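
The series above leaves the migration unfinished (patch 4 is explicitly WIP). As a rough orientation for how the reworked traits are meant to fit together, the sketch below drives a single episode against an environment whose action space is NonNegativeIntegers, i.e. whose action values are u64, matching the new take_action(&u64) implementations. It relies only on the signatures visible in patches 2-4; the helper name play_episode, the import paths and the win/loss handling are illustrative and are not part of the patches.

use ndarray::{Array1, Array2};
use rust_rl::rl::agent::Agent;
use rust_rl::rl::env::Environment;
use spaces::discrete::NonNegativeIntegers;

// Sketch only, not part of the patch series: one episode driven through the
// reshaped traits (associated StateSpace/ActionSpace types, step(),
// take_action(&Action), get_move(...) -> &Action).
fn play_episode<E, Ag>(env: &mut E, agent: &mut Ag)
where
    E: Environment<ActionSpace = NonNegativeIntegers>,
    Ag: Agent<E::StateSpace, E::ActionSpace>,
{
    loop {
        // As before the migration, step() returns the board encoding, the
        // legal-move mask, the last reward and the done flag.
        let (board, legal_moves, reward, done): (Array2<f32>, Array1<bool>, f32, bool) =
            env.step();
        if done {
            // Report a simplified +1/-1 result; the real Trainer derives this
            // from the final game outcome.
            agent.finish_round(if reward > 0.0 { 1 } else { -1 }, board);
            break;
        }
        // get_move() now hands back a reference to a space value (&u64 here)
        // instead of a plain usize, so copy it out before acting on it.
        let action: u64 = *agent.get_move(board, legal_moves, reward);
        // take_action() takes the chosen action by reference and reports
        // whether the move was legal.
        if !env.take_action(&action) {
            // The real Trainer would ask the same agent again; stop here.
            break;
        }
    }
}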