-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
114 lines (80 loc) · 3.09 KB
/
train.py
File metadata and controls
114 lines (80 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import tictactoe as rules
from game import transitions
from mdp import MDP
from agent import QLearner
class AgentWrapper:
    """Base wrapper around an agent; subclasses pick the move policy.

    Equality and hashing delegate to the wrapped agent so wrappers can be
    used as dictionary keys (e.g. in a scores table).
    """

    def __init__(self, agent):
        self.agent = agent

    def __repr__(self):
        # e.g. "TrainingWrapper(QLearner(...))"
        return f"{self.__class__.__name__}({self.agent!r})"

    def __hash__(self):
        # Mix in the wrapper class so Training/Testing wrappers of the
        # same agent hash differently.
        return hash(self.__class__) ^ hash(self.agent)

    def __eq__(self, other):
        # type identity comparison: a TrainingWrapper is never equal to a
        # TestingWrapper, even around the same agent
        return type(self) is type(other) and self.agent == other.agent

    def move(self, move):
        """Select a move; concrete subclasses must override this."""
        raise NotImplementedError("AgentWrapper.move is abstract!")
class TrainingWrapper(AgentWrapper):
    """Wrapper used while learning: forwards updates, explores via
    the agent's ``choose_next_action``."""

    def update(self, *args):
        # Pass the observed transition through so the agent can learn.
        self.agent.update(*args)

    def move(self, state):
        # Delegate to the agent's (possibly exploratory) action choice.
        return self.agent.choose_next_action(state)
class TestingWrapper(AgentWrapper):
    """Wrapper used for evaluation: no learning, always plays the
    agent's ``optimal_action``."""

    def update(self, *args):
        # Evaluation is read-only: discard transitions instead of learning.
        pass

    def move(self, state):
        # Greedy play only — no exploration during evaluation.
        return self.agent.optimal_action(state)
def run_for_episodes(mdp, episodes, players, debug=False):
    """Play ``episodes`` games of ``mdp`` between ``players``.

    Every transition is fed back to the player who made the move (a
    TrainingWrapper learns from it; a TestingWrapper ignores it), and the
    transition reward is accumulated per player.  The starting player is
    rotated after each episode so neither side always opens.

    Returns a dict mapping each player wrapper to its total reward.
    """
    scores = dict.fromkeys(players, 0)
    for _ in range(episodes):
        for t in transitions(mdp, players):
            t.player.update(t.prev_state,
                            t.move,
                            t.post_opp_move_state,
                            t.reward)
            scores[t.player] += t.reward
        # rotate the tuple so a different player opens the next episode
        players = (*players[1:], players[0])
        if debug:
            print(scores)
    return scores
def epoch(mdp, training_episodes, testing_episodes, champion, contender, debug=False):
    """Run one train/evaluate cycle between ``champion`` and ``contender``.

    Training phase: the champion is frozen (greedy play, no updates) while
    the contender learns against it for ``training_episodes`` episodes.
    Evaluation phase: both agents play greedily with learning disabled for
    ``testing_episodes`` episodes.

    Returns the evaluation scores dict from ``run_for_episodes`` (keyed by
    the TestingWrapper instances).
    """
    champion = TestingWrapper(champion)
    contender = TrainingWrapper(contender)
    # train: contender learns against the frozen champion
    _ = run_for_episodes(mdp, training_episodes, (champion, contender))
    # evaluate: both greedy, no learning
    champion = TestingWrapper(champion.agent)
    contender = TestingWrapper(contender.agent)
    # BUG FIX: the evaluation run previously used `training_episodes`,
    # silently ignoring the `testing_episodes` argument.
    return run_for_episodes(mdp, testing_episodes, (champion, contender), debug)
def best_agent(results):
    """Return the highest-scoring entry of ``results`` (a score dict),
    unwrapping it if it is an AgentWrapper."""
    winner = max(results, key=results.get)
    return winner.agent if isinstance(winner, AgentWrapper) else winner
if __name__ == '__main__':
    # Self-play hyper-parameters.
    num_epochs = 5000
    display_step = 50
    num_training_episodes = 100
    num_testing_episodes = 10

    mdp = MDP(rules)
    champion = QLearner()
    contender = QLearner(epsilon=0.5)  # exploratory challenger

    for itr in range(num_epochs):
        results = epoch(mdp,
                        num_training_episodes,
                        num_testing_episodes,
                        champion,
                        contender)
        # Periodic checkpoint and progress report.
        if itr % display_step == 0:
            champion.save('champion.p')
            print(itr, results)
        # Promote the epoch winner; the next contender is a copy of it.
        champion = best_agent(results)
        contender = champion.copy()

    # Dump the learned Q-table, grouped by board position.
    with open('table.txt', 'w') as q_table_out:
        grouped = {}
        for position, move in champion.q_table:
            grouped.setdefault(position, []).append(
                (move, champion.q_table[(position, move)]))
        for position, entries in grouped.items():
            q_table_out.write(str(position) + "\n")
            for move, value in entries:
                q_table_out.write("\t" + str(move) + " ->" + str(value) + "\n")