Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ logout
.Rhistory
.RData
.Ruserdata
replay_memory
8 changes: 7 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ Package: rlR
Type: Package
Title: Reinforcement Learning in R
Version: 0.1.0
Authors@R: person("Xudong", "Sun", email = {"smilesun.east@gmail.com"}, role = c("aut", "cre"))
Authors@R: c(
person("Xudong", "Sun", email = {"smilesun.east@gmail.com"}, role = c("aut", "cre")),
person("Sebastian", "Gruber", email = {"gruber_sebastian@t-online.de"}, role = c("ctb"))
)
Maintainer: Xudong Sun <smilesun.east@gmail.com>
Description: Reinforcement Learning with deep Q learning, double deep Q
learning, frozen target deep Q learning, policy gradient deep learning, policy
Expand All @@ -21,6 +24,9 @@ Imports:
logging,
ggplot2,
openssl,
RSQLite,
png,
stringr,
magrittr,
abind
LazyData: true
Expand Down
226 changes: 226 additions & 0 deletions R/replaymemdb.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
#' @importFrom magrittr %>% %<>%
#FIXME: RSQLite::dbDisconnect(agent$mem$db.con) run this at the end of try-catch-finally
ReplayMemDB = R6::R6Class(
"ReplayMemDB",
inherit = ReplayMem,
public = list(
len = NULL,
replayed.idx = NULL,
conf = NULL,
agent = NULL,
db.con = NULL,
table.name = NULL,
initialize = function(agent, conf) {
# initialize sqlite connection
self$db.con = RSQLite::dbConnect(RSQLite::SQLite(), dbname = "replay_memory") # RSQLite::SQLite() load the driver
# pick the env name as table name
self$table.name = agent$env$env %>%
stringr::str_extract("<([a-z]|[A-Z]|-|[0-9])*>") %>% # remove the special signs
stringr::str_remove_all("<|>") # there's maybe a better solution
# delete old replay table
RSQLite::dbExecute( self$db.con, paste0("DROP TABLE IF EXISTS '", self$table.name, "'") )
# manually create new table to specify primary key - this reduces index search complexity from O(n) to O(log n)
RSQLite::dbExecute( self$db.con, paste0("
CREATE TABLE '", self$table.name, "' (
state_id INTEGER PRIMARY KEY,
state_old TEXT,
reward NUMERIC,
action INTEGER,
state_new TEXT,
done INTEGER,
episode INTEGER,
stepidx INTEGER,
info TEXT )
") )
self$len = 0L
self$conf = conf
self$agent = agent
},

reset = function() {
RSQLite::dbExecute( self$db.con, paste0("DROP TABLE '", self$table.name, "'") )
self$len = 0L
},

mkInst = function(state.old, action, reward, state.new, done, info) {
# transform/compress states into single string for DB entry
if (length(self$agent$state_dim) == 1) {
state.old %<>% paste(collapse = "_")
state.new %<>% paste(collapse = "_")
} else {
state.old = (state.old / 255L) %>% (png::writePNG) %>% paste(collapse = "")
state.new = (state.new / 255L) %>% (png::writePNG) %>% paste(collapse = "")
}
self$len = self$len + 1L
# don't use "." in column names - SQLite will throw up on it
data.frame(
state_id = self$len,
#state.hash = digest(old_state, algo = "md5"),
state_old = state.old,
reward = reward,
action = action,
state_new = state.new,
done = done,
episode = info$episode,
stepidx = info$stepidx,
info = if (length(info$info)==0) "NULL" else info$info %>% as.character()
)
},

add = function(ins) {
# write to sqlite table
RSQLite::dbWriteTable( self$db.con, self$table.name, ins, append = TRUE)
},

afterEpisode = function(interact) {
# do nothing
},

afterStep = function() {
# do nothing
},

getSamples = function(idx) {

str_to_array = function(string) {

if (length(self$agent$state_dim) == 1) {
# if order of tensor is only 1, which means flat linear state
strsplit(string, "_")[[1]] %>% # self defined format of the string, now split it by spliter '_'
as.numeric() %>%
array()
} else if (length(self$agent$state_dim) %in% 2L:3L) {
change_storage = function(y) {
storage.mode(y) = "integer" # change storage type to integer to save space
y
}
(
# magittr require ()
string %>%
strsplit("") %>% # ABEF39 SPLIT into c("A", "B", "E", ...)
.[[1]] %>% # return of split is a list
(function(x) paste0(x[c(TRUE, FALSE)], x[c(FALSE, TRUE)])) %>% #combine to pairs, equivalent to zip: x[c(TRUE, FALSE)] takes the 1st,3st,5st and x[c(FALSE, TRUE)] take the 2st, 4st
as.hexmode %>% # necessary for correct as.raw. For R to understand this is hexcode other than String.
as.raw %>% # make it readable as PNG
(png::readPNG) * 255 # png package assums image to have range 0-1
) %>%
change_storage %>% # float storage to int storage
array(dim = self$agent$state_dim) # this is necessary if state_dim has shape x1 x2 1
#FIXME: IS THE Orientation of the array right! Critically Important
}
}

replay.samples = paste0("
SELECT state_old, action, reward, state_new, done, info, episode, stepidx
FROM '", self$table.name, "'
WHERE state_id IN (", paste(idx, collapse = ", "), ")
") %>%
RSQLite::dbGetQuery(conn = self$db.con)
# replay.samples now are the results from the query
lapply(1:nrow(replay.samples), function(i) list(
state.old = replay.samples$state_old[i] %>% str_to_array,
action = replay.samples$action[i],
reward = replay.samples$reward[i],
state.new = replay.samples$state_new[i] %>% str_to_array,
done = replay.samples$done[i],
info = list(
episode = replay.samples$episode[i],
stepidx = replay.samples$stepidx[i],
info = replay.samples$info[i]
)
))
}
),
private = list(),
active = list()
)


ReplayMemUniformDB = R6::R6Class("ReplayMemUniformDB",
inherit = ReplayMemDB,
public = list(
sample.fun = function(k) {
k = min(k, self$len)
self$replayed.idx = sample(self$len)[1L:k]
list.res = self$getSamples(self$replayed.idx)
return(list.res)
}
),
private = list(),
active = list()
)


ReplayMemPng = R6::R6Class(
"ReplayMemPng",
inherit = ReplayMemUniform,
public = list(
len = NULL,
replayed.idx = NULL,
conf = NULL,
agent = NULL,
table.name = NULL,
initialize = function(agent, conf) {
super$initialize(agent, conf)
},

mkInst = function(state.old, action, reward, state.new, done, info) {
# transform/compress states into single string for DB entry
if (length(self$agent$state_dim) == 1) {
state.old %<>% paste(collapse = "_")
state.new %<>% paste(collapse = "_")
} else {
state.old = (state.old / 255L) %>% (png::writePNG) %>% paste(collapse = "")
state.new = (state.new / 255L) %>% (png::writePNG) %>% paste(collapse = "")
}
super$mkInst(state.old, action, reward, state.new, done, info)
},

sample.fun = function(k) {
k = min(k, self$size)
#FIXME: the replayed.idx are not natural index, but just the position in the replay memory
self$replayed.idx = sample(self$size)[1L:k]
list.res = lapply(self$replayed.idx, function(x) self$samples[[x]])
replay.samples = list.res
# replay.samples now are the results from the query

#FIXME: IS THE Orientation of the array right! Critically Important
list.replay = lapply(replay.samples, function(x) list(
state.old = x$state.old %>% str_to_array_h %>% array(dim = self$agent$state_dim),
action = x$action,
reward = x$reward,
state.new = x$state.new %>% str_to_array_h %>% array(dim = self$agent$state_dim),
done = x$done,
info = list(
episode = x$episode,
stepidx = x$stepidx,
info = x$info
)
))
list.replay # DEBUG: self$agent$env$showImage(list.replay[[64]][["state.new"]]) make sense
}
),
private = list(),
active = list()
)



change_storage = function(y) {
storage.mode(y) = "integer" # change storage type to integer to save space
y
}

str_to_array_h = function(string) {
(
# magittr require ()
string %>%
strsplit("") %>% # ABEF39 SPLIT into c("A", "B", "E", ...)
.[[1]] %>% # return of split is a list
(function(x) paste0(x[c(TRUE, FALSE)], x[c(FALSE, TRUE)])) %>% #combine to pairs, equivalent to zip: x[c(TRUE, FALSE)] takes the 1st,3st,5st and x[c(FALSE, TRUE)] take the 2st, 4st
as.hexmode %>% # necessary for correct as.raw. For R to understand this is hexcode other than String.
as.raw %>% # make it readable as PNG
(png::readPNG) * 255 # png package assums image to have range 0-1
) %>%
change_storage # float storage to int storage
}
11 changes: 11 additions & 0 deletions tests/testthat/test_file_replay_mem_png.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
context("replay_mem")
test_that("test basic replay_mem works", {
env = makeGymEnv("Pong-v0")
env$overview()
conf = getDefaultConf("AgentDQN")
conf$show()
conf$set(render = T, console = T, replay.memname = "Png")
agent = makeAgent("AgentDQN", env, conf)
perf = agent$learn(1L)
agent$plotPerf()
})