diff --git a/.gitignore b/.gitignore index 0da876b..7f6f58f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ logout .Rhistory .RData .Ruserdata +replay_memory diff --git a/DESCRIPTION b/DESCRIPTION index 486c22b..608c0de 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,7 +2,10 @@ Package: rlR Type: Package Title: Reinforcement Learning in R Version: 0.1.0 -Authors@R: person("Xudong", "Sun", email = {"smilesun.east@gmail.com"}, role = c("aut", "cre")) +Authors@R: c( + person("Xudong", "Sun", email = {"smilesun.east@gmail.com"}, role = c("aut", "cre")), + person("Sebastian", "Gruber", email = {"gruber_sebastian@t-online.de"}, role = c("ctb")) + ) Maintainer: Xudong Sun Description: Reinforcement Learning with deep Q learning, double deep Q learning, frozen target deep Q learning, policy gradient deep learning, policy @@ -21,6 +24,9 @@ Imports: logging, ggplot2, openssl, + RSQLite, + png, + stringr, magrittr, abind LazyData: true diff --git a/R/replaymemdb.R b/R/replaymemdb.R new file mode 100644 index 0000000..7d4a05a --- /dev/null +++ b/R/replaymemdb.R @@ -0,0 +1,226 @@ +#' @importFrom magrittr %>% %<>% +#FIXME: RSQLite::dbDisconnect(agent$mem$db.con) run this at the end of try-catch-finally +ReplayMemDB = R6::R6Class( + "ReplayMemDB", + inherit = ReplayMem, + public = list( + len = NULL, + replayed.idx = NULL, + conf = NULL, + agent = NULL, + db.con = NULL, + table.name = NULL, + initialize = function(agent, conf) { + # initialize sqlite connection + self$db.con = RSQLite::dbConnect(RSQLite::SQLite(), dbname = "replay_memory") # RSQLite::SQLite() load the driver + # pick the env name as table name + self$table.name = agent$env$env %>% + stringr::str_extract("<([a-z]|[A-Z]|-|[0-9])*>") %>% # remove the special signs + stringr::str_remove_all("<|>") # there's maybe a better solution + # delete old replay table + RSQLite::dbExecute( self$db.con, paste0("DROP TABLE IF EXISTS '", self$table.name, "'") ) + # manually create new table to specify primary key - this reduces index search complexity from O(n) to O(log n) + RSQLite::dbExecute( self$db.con, paste0(" + CREATE TABLE '", self$table.name, "' ( + state_id INTEGER PRIMARY KEY, + state_old TEXT, + reward NUMERIC, + action INTEGER, + state_new TEXT, + done INTEGER, + episode INTEGER, + stepidx INTEGER, + info TEXT ) + ") ) + self$len = 0L + self$conf = conf + self$agent = agent + }, + + reset = function() { + RSQLite::dbExecute( self$db.con, paste0("DROP TABLE '", self$table.name, "'") ) + self$len = 0L + }, + + mkInst = function(state.old, action, reward, state.new, done, info) { + # transform/compress states into single string for DB entry + if (length(self$agent$state_dim) == 1) { + state.old %<>% paste(collapse = "_") + state.new %<>% paste(collapse = "_") + } else { + state.old = (state.old / 255L) %>% (png::writePNG) %>% paste(collapse = "") + state.new = (state.new / 255L) %>% (png::writePNG) %>% paste(collapse = "") + } + self$len = self$len + 1L + # don't use "." in column names - SQLite will throw up on it + data.frame( + state_id = self$len, + #state.hash = digest(old_state, algo = "md5"), + state_old = state.old, + reward = reward, + action = action, + state_new = state.new, + done = done, + episode = info$episode, + stepidx = info$stepidx, + info = if (length(info$info)==0) "NULL" else info$info %>% as.character() + ) + }, + + add = function(ins) { + # write to sqlite table + RSQLite::dbWriteTable( self$db.con, self$table.name, ins, append = TRUE) + }, + + afterEpisode = function(interact) { + # do nothing + }, + + afterStep = function() { + # do nothing + }, + + getSamples = function(idx) { + + str_to_array = function(string) { + + if (length(self$agent$state_dim) == 1) { + # if order of tensor is only 1, which means flat linear state + strsplit(string, "_")[[1]] %>% # self defined format of the string, now split it by spliter '_' + as.numeric() %>% + array() + } else if (length(self$agent$state_dim) %in% 2L:3L) { + change_storage = function(y) { + storage.mode(y) = "integer" # change storage type to integer to save space + y + } + ( + # magittr require () + string %>% + strsplit("") %>% # ABEF39 SPLIT into c("A", "B", "E", ...) + .[[1]] %>% # return of split is a list + (function(x) paste0(x[c(TRUE, FALSE)], x[c(FALSE, TRUE)])) %>% #combine to pairs, equivalent to zip: x[c(TRUE, FALSE)] takes the 1st,3st,5st and x[c(FALSE, TRUE)] take the 2st, 4st + as.hexmode %>% # necessary for correct as.raw. For R to understand this is hexcode other than String. + as.raw %>% # make it readable as PNG + (png::readPNG) * 255 # png package assums image to have range 0-1 + ) %>% + change_storage %>% # float storage to int storage + array(dim = self$agent$state_dim) # this is necessary if state_dim has shape x1 x2 1 + #FIXME: IS THE Orientation of the array right! Critically Important + } + } + + replay.samples = paste0(" + SELECT state_old, action, reward, state_new, done, info, episode, stepidx + FROM '", self$table.name, "' + WHERE state_id IN (", paste(idx, collapse = ", "), ") + ") %>% + RSQLite::dbGetQuery(conn = self$db.con) + # replay.samples now are the results from the query + lapply(1:nrow(replay.samples), function(i) list( + state.old = replay.samples$state_old[i] %>% str_to_array, + action = replay.samples$action[i], + reward = replay.samples$reward[i], + state.new = replay.samples$state_new[i] %>% str_to_array, + done = replay.samples$done[i], + info = list( + episode = replay.samples$episode[i], + stepidx = replay.samples$stepidx[i], + info = replay.samples$info[i] + ) + )) + } + ), + private = list(), + active = list() +) + + +ReplayMemUniformDB = R6::R6Class("ReplayMemUniformDB", + inherit = ReplayMemDB, + public = list( + sample.fun = function(k) { + k = min(k, self$len) + self$replayed.idx = sample(self$len)[1L:k] + list.res = self$getSamples(self$replayed.idx) + return(list.res) + } + ), + private = list(), + active = list() + ) + + +ReplayMemPng = R6::R6Class( + "ReplayMemPng", + inherit = ReplayMemUniform, + public = list( + len = NULL, + replayed.idx = NULL, + conf = NULL, + agent = NULL, + table.name = NULL, + initialize = function(agent, conf) { + super$initialize(agent, conf) + }, + + mkInst = function(state.old, action, reward, state.new, done, info) { + # transform/compress states into single string for DB entry + if (length(self$agent$state_dim) == 1) { + state.old %<>% paste(collapse = "_") + state.new %<>% paste(collapse = "_") + } else { + state.old = (state.old / 255L) %>% (png::writePNG) %>% paste(collapse = "") + state.new = (state.new / 255L) %>% (png::writePNG) %>% paste(collapse = "") + } + super$mkInst(state.old, action, reward, state.new, done, info) + }, + + sample.fun = function(k) { + k = min(k, self$size) + #FIXME: the replayed.idx are not natural index, but just the position in the replay memory + self$replayed.idx = sample(self$size)[1L:k] + list.res = lapply(self$replayed.idx, function(x) self$samples[[x]]) + replay.samples = list.res + # replay.samples now are the results from the query + + #FIXME: IS THE Orientation of the array right! Critically Important + list.replay = lapply(replay.samples, function(x) list( + state.old = x$state.old %>% str_to_array_h %>% array(dim = self$agent$state_dim), + action = x$action, + reward = x$reward, + state.new = x$state.new %>% str_to_array_h %>% array(dim = self$agent$state_dim), + done = x$done, + info = list( + episode = x$episode, + stepidx = x$stepidx, + info = x$info + ) + )) + list.replay # DEBUG: self$agent$env$showImage(list.replay[[64]][["state.new"]]) make sense + } + ), + private = list(), + active = list() +) + + + +change_storage = function(y) { + storage.mode(y) = "integer" # change storage type to integer to save space + y +} + +str_to_array_h = function(string) { + ( + # magittr require () + string %>% + strsplit("") %>% # ABEF39 SPLIT into c("A", "B", "E", ...) + .[[1]] %>% # return of split is a list + (function(x) paste0(x[c(TRUE, FALSE)], x[c(FALSE, TRUE)])) %>% #combine to pairs, equivalent to zip: x[c(TRUE, FALSE)] takes the 1st,3st,5st and x[c(FALSE, TRUE)] take the 2st, 4st + as.hexmode %>% # necessary for correct as.raw. For R to understand this is hexcode other than String. + as.raw %>% # make it readable as PNG + (png::readPNG) * 255 # png package assums image to have range 0-1 + ) %>% + change_storage # float storage to int storage +} diff --git a/tests/testthat/test_file_replay_mem_png.R b/tests/testthat/test_file_replay_mem_png.R new file mode 100644 index 0000000..0f80998 --- /dev/null +++ b/tests/testthat/test_file_replay_mem_png.R @@ -0,0 +1,11 @@ +context("replay_mem") +test_that("test basic replay_mem works", { +env = makeGymEnv("Pong-v0") +env$overview() +conf = getDefaultConf("AgentDQN") +conf$show() +conf$set(render = T, console = T, replay.memname = "Png") +agent = makeAgent("AgentDQN", env, conf) +perf = agent$learn(1L) +agent$plotPerf() +})