-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmemory.py
More file actions
106 lines (92 loc) · 4.54 KB
/
memory.py
File metadata and controls
106 lines (92 loc) · 4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import numpy as np
class ReplayMemory:
    """
    A ring-buffer replay memory that keeps at most ``maxlen`` experiences and
    provides methods to store and randomly sample them.
    """
    def __init__(self, maxlen=1000000, game_over_bias=0):
        """
        :param maxlen: maximum number of experiences kept; once full, the
            oldest entries are overwritten.
        :param game_over_bias: store each game-over transition this many
            EXTRA times (a cheap, minimal form of prioritized replay).
        """
        self.maxlen = maxlen
        # dtype=object (builtin): np.object was deprecated and removed in
        # NumPy >= 1.24, so the original `dtype=np.object` crashes there.
        self.buf = np.empty(shape=maxlen, dtype=object)
        self.index = 0   # next write position (wraps around)
        self.length = 0  # number of valid entries, capped at maxlen
        self.game_over_bias = game_over_bias

    def append(self, data):
        """
        Add one experience, overwriting the oldest entry once the buffer
        is full.

        :param data: the experience to store (any object)
        """
        self.buf[self.index] = data
        self.length = min(self.length + 1, self.maxlen)
        self.index = (self.index + 1) % self.maxlen

    def sample(self, batch_size, with_replacement=True):
        """
        Retrieve ``batch_size`` experiences chosen uniformly at random.

        :param batch_size: number of experiences to retrieve
        :param with_replacement: whether the same experience may be drawn
            twice (faster!)
        :return: an np.array (dtype=object) of experiences
        :raises ValueError: if the memory is empty
        """
        if self.length == 0:
            # np.random.randint(0) would raise an opaque "low >= high" error;
            # make the failure mode explicit instead.
            raise ValueError("cannot sample from an empty replay memory")
        if with_replacement:
            indices = np.random.randint(self.length, size=batch_size)
        else:
            indices = np.random.permutation(self.length)[:batch_size]
        return self.buf[indices]

    def store_observation(self, state, action, reward, next_state, game_over):
        """
        Store the observation of one experience.

        :param state: the state before acting
        :param action: the action taken
        :param reward: the reward received after taking the action; forced
            to -1 when the game ended (terminal penalty)
        :param next_state: the state after taking the action
        :param game_over: whether the game is over after taking the action
        """
        if game_over:
            reward = -1
            # Duplicate terminal transitions game_over_bias times so they
            # are sampled proportionally more often.
            for _ in range(self.game_over_bias):
                self.append([state, action, reward, next_state, game_over])
        self.append([state, action, reward, next_state, game_over])

    def get_replays(self, num_plays):
        """
        Get ``num_plays`` random experiences, split into per-field arrays.

        :param num_plays: number of experiences to get
        :return: (states, actions, rewards(-1,1), next_states, game_overs(-1,1))
            as np.arrays; rewards and game_overs are reshaped to columns.
        """
        replays = self.sample(num_plays)
        cols = [[], [], [], [], []]
        for memory in replays:
            for col, value in zip(cols, memory):
                col.append(value)
        cols = [np.array(col) for col in cols]
        return (cols[0], cols[1], cols[2].reshape(-1, 1), cols[3], cols[4].reshape(-1, 1))
class ReplayMemoryStatic:
    """
    Same contract as ReplayMemory, implemented with preallocated fixed-size
    numpy arrays per field for faster processing (though the generic
    ReplayMemory implementation turned out to be faster in practice).
    """
    def __init__(self, maxlen=1000000, image_size=(84, 84), minibatch_size=32):
        """
        :param maxlen: capacity of the ring buffer
        :param image_size: (height, width) of stored state images
            # assumes single-channel int8-compatible frames — TODO confirm
        :param minibatch_size: size of the reusable sampling buffers
        """
        self._memory_state = np.zeros(shape=(maxlen, image_size[0], image_size[1], 1), dtype=np.int8)
        self._memory_future_state = np.zeros(shape=(maxlen, image_size[0], image_size[1], 1), dtype=np.int8)
        self._rewards = np.zeros(shape=(maxlen, 1), dtype=np.float32)
        # dtype=bool (builtin): np.bool was deprecated and removed in
        # NumPy >= 1.24, so the original `dtype=np.bool` crashes there.
        self._is_terminal = np.zeros(shape=(maxlen, 1), dtype=bool)
        self._actions = np.zeros(shape=(maxlen, 1), dtype=np.int8)
        # Reusable float32 minibatch buffers to avoid re-allocating each call.
        self._mini_batch_state = np.zeros(shape=(minibatch_size, image_size[0], image_size[1], 1), dtype=np.float32)
        self._mini_batch_future_state = np.zeros(shape=(minibatch_size, image_size[0], image_size[1], 1), dtype=np.float32)
        self._mini_batch_size = minibatch_size
        self._maxlen = maxlen
        self._counter = 0  # total observations ever stored (not capped)

    def store_observation(self, state, action, reward, future_state, is_terminal):
        """
        Store one transition at the next ring-buffer slot.

        :param state: state image before acting
        :param action: action taken
        :param reward: reward received
        :param future_state: state image after acting
        :param is_terminal: whether the episode ended
        """
        position = self._counter % self._maxlen
        self._memory_state[position, :, :, :] = state
        self._memory_future_state[position, :, :, :] = future_state
        self._rewards[position] = reward
        self._is_terminal[position] = is_terminal
        self._actions[position] = action
        self._counter += 1

    def get_replays(self, num_plays):
        """
        Sample ``num_plays`` transitions uniformly (with replacement).

        :param num_plays: number of transitions to sample
        :return: (states, actions, rewards, future_states, is_terminals)
        :raises ValueError: if nothing has been stored yet
        """
        # Bug fix: only sample from slots that have actually been written.
        # The original drew indices from the full maxlen range, returning
        # all-zero garbage transitions until the buffer was completely full.
        filled = min(self._counter, self._maxlen)
        if filled == 0:
            raise ValueError("cannot sample from an empty replay memory")
        ind = np.random.choice(filled, size=num_plays)
        if num_plays == self._mini_batch_size:
            # Fast path: reuse the preallocated buffers (avoids a fresh copy).
            batch_state = self._mini_batch_state
            batch_future_state = self._mini_batch_future_state
        else:
            # The original broadcast-assigned into fixed-size buffers and
            # crashed for any other batch size; allocate matching arrays.
            shape = (num_plays,) + self._memory_state.shape[1:]
            batch_state = np.zeros(shape, dtype=np.float32)
            batch_future_state = np.zeros(shape, dtype=np.float32)
        batch_state[:] = self._memory_state[ind, :, :, :]
        batch_future_state[:] = self._memory_future_state[ind, :, :, :]
        rewards = self._rewards[ind]
        is_terminal = self._is_terminal[ind]
        actions = self._actions[ind]
        return batch_state, actions, rewards, batch_future_state, is_terminal