From 39c89e1bb4e903d3df0aa29c9042e2ea6f4894bb Mon Sep 17 00:00:00 2001 From: smilesun Date: Sat, 29 Sep 2018 15:51:44 +0200 Subject: [PATCH 1/3] . --- .gitignore | 1 + DESCRIPTION | 8 +- R/replaymemdb.R | 224 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 R/replaymemdb.R diff --git a/.gitignore b/.gitignore index 0da876b..7f6f58f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ logout .Rhistory .RData .Ruserdata +replay_memory diff --git a/DESCRIPTION b/DESCRIPTION index 21f4e15..f02b74b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,7 +2,10 @@ Package: rlR Type: Package Title: Reinforcement Learning in R Version: 0.1.0 -Authors@R: person("Xudong", "Sun", email = {"smilesun.east@gmail.com"}, role = c("aut", "cre")) +Authors@R: c( + person("Xudong", "Sun", email = {"smilesun.east@gmail.com"}, role = c("aut", "cre")), + person("Sebastian", "Gruber", email = {"gruber_sebastian@t-online.de"}, role = c("ctb")) + ) Maintainer: Xudong Sun Description: Reinforcement Learning with deep Q learning, double deep Q learning, frozen target deep Q learning, policy gradient deep learning, policy @@ -21,6 +24,9 @@ Imports: logging, ggplot2, openssl, + RSQLite, + png, + stringr, imager, magrittr, abind diff --git a/R/replaymemdb.R b/R/replaymemdb.R new file mode 100644 index 0000000..97f9538 --- /dev/null +++ b/R/replaymemdb.R @@ -0,0 +1,224 @@ +#' @importFrom magrittr %>% %<>% +ReplayMemDB = R6::R6Class( + "ReplayMemDB", + inherit = ReplayMem, + public = list( + dt = NULL, + len = NULL, + replayed.idx = NULL, + conf = NULL, + agent = NULL, + dt.temp = NULL, + smooth = NULL, + db.con = NULL, + table.name = NULL, + initialize = function(agent, conf) { + # initialize sqlite connection + self$db.con = RSQLite::dbConnect(RSQLite::SQLite(), dbname = "replay_memory") + # pick the env name as table name + self$table.name = agent$env$env %>% + stringr::str_extract("<([a-z]|[A-Z]|-|[0-9])*>") %>% + stringr::str_remove_all("<|>") # there's maybe a better solution + # delete old replay table + RSQLite::dbExecute( self$db.con, paste0("DROP TABLE IF EXISTS '", self$table.name, "'") ) + # manually create new table to specify primary key - this reduces index search complexity from O(n) to O(log n) + RSQLite::dbExecute( self$db.con, paste0(" + CREATE TABLE '", self$table.name, "' ( + state_id INTEGER PRIMARY KEY, + state_old TEXT, + reward NUMERIC, + action INTEGER, + state_new TEXT, + done INTEGER, + episode INTEGER, + stepidx INTEGER, + info TEXT ) + ") ) + self$smooth = rlR.conf4log[["replay.mem.laplace.smoother"]] + self$dt = data.table() + self$len = 0L + self$conf = conf + self$agent = agent + # helper constant variable + self$dt.temp = data.table("delta" = NA, "priorityRank" = NA, "priorityAbs" = NA, "priorityDelta2" = NA, "deltaOfdelta" = NA, "deltaOfdeltaPercentage" = NA) + self$dt.temp = self$dt.temp[, lapply(.SD, as.numeric)] + }, + + reset = function() { + RSQLite::dbExecute( self$db.con, paste0("DROP TABLE '", self$table.name, "'") ) + self$dt = data.table() + self$len = 0L + }, + + mkInst = function(state.old, action, reward, state.new, done, info) { + # transform/compress states into single string for DB entry + if (length(self$agent$state_dim) == 1) { + state.old %<>% paste(collapse = "_") + state.new %<>% paste(collapse = "_") + } else { + state.old = (state.old / 255L) %>% (png::writePNG) %>% paste(collapse = "") + state.new = (state.new / 255L) %>% (png::writePNG) %>% paste(collapse = "") + } + self$len = self$len + 1L + # don't use "." in column names - SQLite will throw up on it + data.frame( + state_id = self$len, + #state.hash = digest(old_state, algo = "md5"), + state_old = state.old, + reward = reward, + action = action, + state_new = state.new, + done = done, + episode = info$episode, + stepidx = info$stepidx, + info = if (length(info$info)==0) "NULL" else info$info %>% as.character() + ) + }, + + add = function(ins) { + # write to sqlite table + RSQLite::dbWriteTable( self$db.con, self$table.name, ins, append = TRUE ) + mdt = data.table(t(unlist(ins))) + mdt = cbind(mdt, self$dt.temp) + self$dt = rbindlist(list(self$dt, mdt), fill = TRUE) + }, + + updateDT = function(idx = NULL) { + if (is.null(idx)) idx = 1L:self$len + list.res = self$getSamples(idx) + td.list = lapply(list.res, self$agent$calculateTDError) + updatedTDError = unlist(td.list) + cat(sprintf("mean TD error: %f\n", mean(updatedTDError))) + old.delta = self$dt[idx, "delta"] + self$dt[idx, "delta"] = updatedTDError + self$updatePriority() + }, + + afterEpisode = function(interact) { + # do nothing + }, + + afterStep = function() { + # do nothing + }, + + updatePriority = function() { + self$dt[, "priorityAbs"] = abs(self$dt[, "delta"]) + self$smooth + self$dt[, "priorityRank"] = order(self$dt[, "delta"]) + }, + + getSamples = function(idx) { + + str_to_array = function(string) { + + if (length(self$agent$state_dim) == 1) { + strsplit(string, "_")[[1]] %>% + as.numeric() %>% + array() + } else if (length(self$agent$state_dim) %in% 2:4) { + change_storage = function(y) {storage.mode(y) <- "integer"; y} + ( + string %>% + strsplit("") %>% + .[[1]] %>% + (function(x) paste0(x[c(TRUE, FALSE)], x[c(FALSE, TRUE)])) %>% #combine to pairs + as.hexmode %>% # necessary for correct as.raw + as.raw %>% # make it readable as PNG + (png::readPNG) * 255 + ) %>% + change_storage %>% + array(dim = self$agent$state_dim) # this is necessary if state_dim has shape x1 x2 1 + } + } + + replay.samples = paste0(" + SELECT state_old, action, reward, state_new, done, info, episode, stepidx + FROM '", self$table.name, "' + WHERE state_id IN (", paste(idx, collapse = ", "), ") + ") %>% + RSQLite::dbGetQuery(conn = self$db.con) + + lapply(1:nrow(replay.samples), function(i) list( + state.old = replay.samples$state_old[i] %>% str_to_array, + action = replay.samples$action[i], + reward = replay.samples$reward[i], + state.new = replay.samples$state_new[i] %>% str_to_array, + done = replay.samples$done[i], + info = list( + episode = replay.samples$episode[i], + stepidx = replay.samples$stepidx[i], + info = replay.samples$info[i] + ) + )) + }, + + # TODO: implement way to pull the whole replay memory + # function taking a list of states (2d/3d/4d arrays) and transforming into video replay_.mp4 in their given order + # input arrays need at least 2 dimensions + # mp4 file is compressed -> information loss -> only makes sense for human eyes + createReplayVideo = function(name, start_state_id = 1, end_state_id, framerate = 25) { + # check if the mp4 file doesn't exist - otherwise ffmpeg will make issues + if (length(self$agent$state_dim) == 1) { + stop("State data format is not suitable for video creation") + + } else if (!file.exists( paste0(getwd(), "/replay_", name, ".mp4")) ) { + # get all states of the replay memory + states = self$getSamples(start_state_id:end_state_id) + + # create PNGs in a temporary directory + tempdir = tempdir() + for (i in 1:(end_state_id-start_state_id)) { + png::writePNG( + states[[i]]$state.old / 255, + target = paste0(tempdir, "/img", stringr::str_pad(i, 7, pad = "0"),".png") + ) + } + # use the tool ffmpeg to create a video out of PNGs + command = paste0( + "ffmpeg -framerate ", framerate, + " -i ", tempdir, "'/img%07d.png' -c:v libx264 -pix_fmt yuv420p ", + getwd(), "/replay_", name, ".mp4" + ) + system(command) + } else { + stop(paste0("The file ", getwd(), "/replay_", name, ".mp4 already exists!")) + } + } + ), + private = list(), + active = list() +) + + +ReplayMemUniformDB = R6::R6Class("ReplayMemUniformDB", + inherit = ReplayMemDB, + public = list( + sample.fun = function(k) { + k = min(k, self$len) + self$replayed.idx = sample(self$len)[1L:k] + list.res = self$getSamples(self$replayed.idx) + return(list.res) + } + ), + private = list(), + active = list() + ) + + +test.run = function(sname, runs, nodes) { + conf = rlR:::RLConf$new( + render = TRUE, + console = FALSE, + log = FALSE, + policy.maxEpsilon = 1, + policy.minEpsilon = 0.001, + policy.decay = exp(-0.001), + policy.name = "EpsilonGreedy", + replay.batchsize = 64L, + replay.memname = "UniformDB", + agent.nn.arch = list(nhidden = nodes, act1 = "relu", act2 = "linear", loss = "mse", lr = 0.00025, kernel_regularizer = "regularizer_l2(l=0.0)", bias_regularizer = "regularizer_l2(l=0.0)")) + + interact = makeGymExperiment(sname = sname, aname = "AgentDQN", conf = conf) + interact$run(runs) +} + From baa64465a12bb1c53186e0ba97e9737601cfbd63 Mon Sep 17 00:00:00 2001 From: smilesun Date: Sat, 29 Sep 2018 16:44:59 +0200 Subject: [PATCH 2/3] clean up but still works --- R/replaymemdb.R | 118 +++++++++--------------------------------------- 1 file changed, 22 insertions(+), 96 deletions(-) diff --git a/R/replaymemdb.R b/R/replaymemdb.R index 97f9538..5f549c9 100644 --- a/R/replaymemdb.R +++ b/R/replaymemdb.R @@ -3,21 +3,18 @@ ReplayMemDB = R6::R6Class( "ReplayMemDB", inherit = ReplayMem, public = list( - dt = NULL, len = NULL, replayed.idx = NULL, conf = NULL, agent = NULL, - dt.temp = NULL, - smooth = NULL, db.con = NULL, table.name = NULL, initialize = function(agent, conf) { # initialize sqlite connection - self$db.con = RSQLite::dbConnect(RSQLite::SQLite(), dbname = "replay_memory") + self$db.con = RSQLite::dbConnect(RSQLite::SQLite(), dbname = "replay_memory") # RSQLite::SQLite() load the driver # pick the env name as table name self$table.name = agent$env$env %>% - stringr::str_extract("<([a-z]|[A-Z]|-|[0-9])*>") %>% + stringr::str_extract("<([a-z]|[A-Z]|-|[0-9])*>") %>% # remove the special signs stringr::str_remove_all("<|>") # there's maybe a better solution # delete old replay table RSQLite::dbExecute( self$db.con, paste0("DROP TABLE IF EXISTS '", self$table.name, "'") ) @@ -34,19 +31,13 @@ ReplayMemDB = R6::R6Class( stepidx INTEGER, info TEXT ) ") ) - self$smooth = rlR.conf4log[["replay.mem.laplace.smoother"]] - self$dt = data.table() self$len = 0L self$conf = conf self$agent = agent - # helper constant variable - self$dt.temp = data.table("delta" = NA, "priorityRank" = NA, "priorityAbs" = NA, "priorityDelta2" = NA, "deltaOfdelta" = NA, "deltaOfdeltaPercentage" = NA) - self$dt.temp = self$dt.temp[, lapply(.SD, as.numeric)] }, reset = function() { RSQLite::dbExecute( self$db.con, paste0("DROP TABLE '", self$table.name, "'") ) - self$dt = data.table() self$len = 0L }, @@ -77,21 +68,7 @@ ReplayMemDB = R6::R6Class( add = function(ins) { # write to sqlite table - RSQLite::dbWriteTable( self$db.con, self$table.name, ins, append = TRUE ) - mdt = data.table(t(unlist(ins))) - mdt = cbind(mdt, self$dt.temp) - self$dt = rbindlist(list(self$dt, mdt), fill = TRUE) - }, - - updateDT = function(idx = NULL) { - if (is.null(idx)) idx = 1L:self$len - list.res = self$getSamples(idx) - td.list = lapply(list.res, self$agent$calculateTDError) - updatedTDError = unlist(td.list) - cat(sprintf("mean TD error: %f\n", mean(updatedTDError))) - old.delta = self$dt[idx, "delta"] - self$dt[idx, "delta"] = updatedTDError - self$updatePriority() + RSQLite::dbWriteTable( self$db.con, self$table.name, ins, append = TRUE) }, afterEpisode = function(interact) { @@ -102,32 +79,33 @@ ReplayMemDB = R6::R6Class( # do nothing }, - updatePriority = function() { - self$dt[, "priorityAbs"] = abs(self$dt[, "delta"]) + self$smooth - self$dt[, "priorityRank"] = order(self$dt[, "delta"]) - }, - getSamples = function(idx) { str_to_array = function(string) { - if (length(self$agent$state_dim) == 1) { - strsplit(string, "_")[[1]] %>% + if (length(self$agent$state_dim) == 1) { + # if order of tensor is only 1, which means flat linear state + strsplit(string, "_")[[1]] %>% # self defined format of the string, now split it by spliter '_' as.numeric() %>% array() - } else if (length(self$agent$state_dim) %in% 2:4) { - change_storage = function(y) {storage.mode(y) <- "integer"; y} - ( + } else if (length(self$agent$state_dim) %in% 2L:3L) { + change_storage = function(y) { + storage.mode(y) = "integer" # change storage type to integer to save space + y + } + ( + # magittr require () string %>% - strsplit("") %>% - .[[1]] %>% - (function(x) paste0(x[c(TRUE, FALSE)], x[c(FALSE, TRUE)])) %>% #combine to pairs - as.hexmode %>% # necessary for correct as.raw + strsplit("") %>% # ABEF39 SPLIT into c("A", "B", "E", ...) + .[[1]] %>% # return of split is a list + (function(x) paste0(x[c(TRUE, FALSE)], x[c(FALSE, TRUE)])) %>% #combine to pairs, equivalent to zip: x[c(TRUE, FALSE)] takes the 1st,3st,5st and x[c(FALSE, TRUE)] take the 2st, 4st + as.hexmode %>% # necessary for correct as.raw. For R to understand this is hexcode other than String. as.raw %>% # make it readable as PNG - (png::readPNG) * 255 + (png::readPNG) * 255 # png package assums image to have range 0-1 ) %>% - change_storage %>% + change_storage %>% # float storage to int storage array(dim = self$agent$state_dim) # this is necessary if state_dim has shape x1 x2 1 + #FIXME: IS THE Orientation of the array right! Critically Important } } @@ -137,7 +115,7 @@ ReplayMemDB = R6::R6Class( WHERE state_id IN (", paste(idx, collapse = ", "), ") ") %>% RSQLite::dbGetQuery(conn = self$db.con) - + # replay.samples now are the results from the query lapply(1:nrow(replay.samples), function(i) list( state.old = replay.samples$state_old[i] %>% str_to_array, action = replay.samples$action[i], @@ -150,41 +128,8 @@ ReplayMemDB = R6::R6Class( info = replay.samples$info[i] ) )) - }, - - # TODO: implement way to pull the whole replay memory - # function taking a list of states (2d/3d/4d arrays) and transforming into video replay_.mp4 in their given order - # input arrays need at least 2 dimensions - # mp4 file is compressed -> information loss -> only makes sense for human eyes - createReplayVideo = function(name, start_state_id = 1, end_state_id, framerate = 25) { - # check if the mp4 file doesn't exist - otherwise ffmpeg will make issues - if (length(self$agent$state_dim) == 1) { - stop("State data format is not suitable for video creation") - - } else if (!file.exists( paste0(getwd(), "/replay_", name, ".mp4")) ) { - # get all states of the replay memory - states = self$getSamples(start_state_id:end_state_id) - - # create PNGs in a temporary directory - tempdir = tempdir() - for (i in 1:(end_state_id-start_state_id)) { - png::writePNG( - states[[i]]$state.old / 255, - target = paste0(tempdir, "/img", stringr::str_pad(i, 7, pad = "0"),".png") - ) - } - # use the tool ffmpeg to create a video out of PNGs - command = paste0( - "ffmpeg -framerate ", framerate, - " -i ", tempdir, "'/img%07d.png' -c:v libx264 -pix_fmt yuv420p ", - getwd(), "/replay_", name, ".mp4" - ) - system(command) - } else { - stop(paste0("The file ", getwd(), "/replay_", name, ".mp4 already exists!")) - } } - ), + ), private = list(), active = list() ) @@ -203,22 +148,3 @@ ReplayMemUniformDB = R6::R6Class("ReplayMemUniformDB", private = list(), active = list() ) - - -test.run = function(sname, runs, nodes) { - conf = rlR:::RLConf$new( - render = TRUE, - console = FALSE, - log = FALSE, - policy.maxEpsilon = 1, - policy.minEpsilon = 0.001, - policy.decay = exp(-0.001), - policy.name = "EpsilonGreedy", - replay.batchsize = 64L, - replay.memname = "UniformDB", - agent.nn.arch = list(nhidden = nodes, act1 = "relu", act2 = "linear", loss = "mse", lr = 0.00025, kernel_regularizer = "regularizer_l2(l=0.0)", bias_regularizer = "regularizer_l2(l=0.0)")) - - interact = makeGymExperiment(sname = sname, aname = "AgentDQN", conf = conf) - interact$run(runs) -} - From 8be824702adc93b6cf99e363e73a97af234143f6 Mon Sep 17 00:00:00 2001 From: smilesun Date: Sat, 29 Sep 2018 17:25:49 +0200 Subject: [PATCH 3/3] in memory png compression storation success but with critical bug --- R/replaymemdb.R | 76 +++++++++++++++++++++++ tests/testthat/test_file_replay_mem_png.R | 11 ++++ 2 files changed, 87 insertions(+) create mode 100644 tests/testthat/test_file_replay_mem_png.R diff --git a/R/replaymemdb.R b/R/replaymemdb.R index 5f549c9..7d4a05a 100644 --- a/R/replaymemdb.R +++ b/R/replaymemdb.R @@ -1,4 +1,5 @@ #' @importFrom magrittr %>% %<>% +#FIXME: RSQLite::dbDisconnect(agent$mem$db.con) run this at the end of try-catch-finally ReplayMemDB = R6::R6Class( "ReplayMemDB", inherit = ReplayMem, @@ -148,3 +149,78 @@ ReplayMemUniformDB = R6::R6Class("ReplayMemUniformDB", private = list(), active = list() ) + + +ReplayMemPng = R6::R6Class( + "ReplayMemPng", + inherit = ReplayMemUniform, + public = list( + len = NULL, + replayed.idx = NULL, + conf = NULL, + agent = NULL, + table.name = NULL, + initialize = function(agent, conf) { + super$initialize(agent, conf) + }, + + mkInst = function(state.old, action, reward, state.new, done, info) { + # transform/compress states into single string for DB entry + if (length(self$agent$state_dim) == 1) { + state.old %<>% paste(collapse = "_") + state.new %<>% paste(collapse = "_") + } else { + state.old = (state.old / 255L) %>% (png::writePNG) %>% paste(collapse = "") + state.new = (state.new / 255L) %>% (png::writePNG) %>% paste(collapse = "") + } + super$mkInst(state.old, action, reward, state.new, done, info) + }, + + sample.fun = function(k) { + k = min(k, self$size) + #FIXME: the replayed.idx are not natural index, but just the position in the replay memory + self$replayed.idx = sample(self$size)[1L:k] + list.res = lapply(self$replayed.idx, function(x) self$samples[[x]]) + replay.samples = list.res + # replay.samples now are the results from the query + + #FIXME: IS THE Orientation of the array right! Critically Important + list.replay = lapply(replay.samples, function(x) list( + state.old = x$state.old %>% str_to_array_h %>% array(dim = self$agent$state_dim), + action = x$action, + reward = x$reward, + state.new = x$state.new %>% str_to_array_h %>% array(dim = self$agent$state_dim), + done = x$done, + info = list( + episode = x$episode, + stepidx = x$stepidx, + info = x$info + ) + )) + list.replay # DEBUG: self$agent$env$showImage(list.replay[[64]][["state.new"]]) make sense + } + ), + private = list(), + active = list() +) + + + +change_storage = function(y) { + storage.mode(y) = "integer" # change storage type to integer to save space + y +} + +str_to_array_h = function(string) { + ( + # magittr require () + string %>% + strsplit("") %>% # ABEF39 SPLIT into c("A", "B", "E", ...) + .[[1]] %>% # return of split is a list + (function(x) paste0(x[c(TRUE, FALSE)], x[c(FALSE, TRUE)])) %>% #combine to pairs, equivalent to zip: x[c(TRUE, FALSE)] takes the 1st,3st,5st and x[c(FALSE, TRUE)] take the 2st, 4st + as.hexmode %>% # necessary for correct as.raw. For R to understand this is hexcode other than String. + as.raw %>% # make it readable as PNG + (png::readPNG) * 255 # png package assums image to have range 0-1 + ) %>% + change_storage # float storage to int storage +} diff --git a/tests/testthat/test_file_replay_mem_png.R b/tests/testthat/test_file_replay_mem_png.R new file mode 100644 index 0000000..0f80998 --- /dev/null +++ b/tests/testthat/test_file_replay_mem_png.R @@ -0,0 +1,11 @@ +context("replay_mem") +test_that("test basic replay_mem works", { +env = makeGymEnv("Pong-v0") +env$overview() +conf = getDefaultConf("AgentDQN") +conf$show() +conf$set(render = T, console = T, replay.memname = "Png") +agent = makeAgent("AgentDQN", env, conf) +perf = agent$learn(1L) +agent$plotPerf() +})