-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAbstractActor.py
More file actions
132 lines (116 loc) · 6.1 KB
/
AbstractActor.py
File metadata and controls
132 lines (116 loc) · 6.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import tensorflow as tf
from tensorflow.keras import Model
from SL.encoder.encoder import Encoder
from SL.decoder.decoder import Decoder
from tensorflow.keras.optimizers import Adam
from diplomacy_research.models.state_space import dict_to_flatten_board_state, dict_to_flatten_prev_orders_state, get_current_season, extract_state_proto, proto_to_prev_orders_state, extract_phase_history_proto
from diplomacy import Game
from constants.constants import INVERSE_ORDER_DICT, INT_SEASON, UNIT_POWER_RL, UNIT_POWER
from data.process import parse_rl_state
class AbstractActor(Model):
    '''
    Abstract actor model for the Diplomacy game.

    Wraps an Encoder over board state / previous orders and a Decoder that
    produces a probability distribution over valid orders. Subclasses must
    implement `loss`.
    '''

    def __init__(self, num_board_blocks, num_order_blocks):
        '''
        Initialization for Encoder Model

        Args:
            num_board_blocks - number of blocks for encoding the board state
            num_order_blocks - number of blocks for encoding previous orders
        '''
        super(AbstractActor, self).__init__()
        # creating encoder and decoder networks
        self.encoder = Encoder(num_board_blocks, num_order_blocks)
        self.decoder = Decoder()
        self.optimizer = Adam(0.001)

    def call(self, state_inputs, order_inputs, power_season, season_input, board_dict, power):
        '''
        Function to run the SL model

        Keyword Args:
            state_inputs - the board state inputs
            order_inputs - the previous order inputs
            power_season - the power and season to be used in film
            season_input - the names of the seasons to be used in creating the mask
            board_dict - the board state dictionary representation
            power - the power whose orders are being decoded (forwarded to the
                decoder's mask construction)

        Returns:
            the decoder output; based on how `get_orders` unpacks it, this is
            presumably a (probs, position_list) pair — confirm against Decoder.call
        '''
        # casting inputs to float32
        state_inputs = tf.cast(state_inputs, tf.float32)
        order_inputs = tf.cast(order_inputs, tf.float32)
        enc_out = self.encoder.call(state_inputs, order_inputs, power_season)
        # extracting positions and masks to use in decoder
        pos_list, masks = self.decoder.create_pos_masks(state_inputs, season_input, board_dict, power)
        dec_out = self.decoder.call(state_inputs, enc_out, pos_list, masks)
        return dec_out

    def loss(self, probs, labels):
        '''
        Function to compute the loss of the Actor

        Keyword Args:
            probs - the probability distribution output over the orders
            labels - the labels representing the actions taken

        Return:
            loss for the actions taken

        Raises:
            NotImplementedError - always; subclasses must override.
        '''
        raise NotImplementedError("Not implemented in abstract class.")

    def get_orders(self, game, power_names):
        """
        Compute greedy (argmax) orders for each requested power.

        :param game: Game object
        :param power_names: A list of power names we are playing
        :return: A tuple (orders, order_probs) where:
            - orders is a list (one entry per power in `power_names`) of lists
              of order strings, decoded via INVERSE_ORDER_DICT
            - order_probs is the matching list of lists of the probabilities
              assigned to each chosen order
        """
        # num_dummies/tiling is a hacky way to get around a TF Strided Slice
        # error that occurs when only passing in one state (batch size of 1).
        num_dummies = 2

        # Build the previous-orders state from the most recent movement phase
        # (phase names ending in "M"); fall back to zeros when there is no
        # history or no movement phase yet (avoids a NameError in the
        # original, which only handled the empty-history case).
        order_history = extract_phase_history_proto(game, 3)
        prev_movement_phase = None
        for phase in reversed(order_history):
            if phase.name[-1] == "M":
                prev_movement_phase = phase
                break
        if prev_movement_phase is None:
            prev_orders_state = tf.zeros((1, 81, 40), dtype=tf.float32)
        else:
            prev_orders_state = proto_to_prev_orders_state(prev_movement_phase, game.map).flatten().tolist()
            prev_orders_state = tf.reshape(prev_orders_state, (1, 81, 40))
        prev_orders_state_with_dummies = tf.tile(prev_orders_state, [num_dummies, 1, 1])

        # Current board state, tiled the same way.
        board_state = dict_to_flatten_board_state(game.get_state(), game.map)
        board_state = tf.reshape(board_state, (1, 81, 35))
        board_state_with_dummies = tf.tile(board_state, [num_dummies, 1, 1])

        season = get_current_season(extract_state_proto(game))
        state = game.get_state()
        year = state["name"]
        board_dict = parse_rl_state(state)

        orders = []
        order_probs = []
        for power in power_names:
            print(power, year)
            # FiLM conditioning vector: power one-hot concatenated with season.
            power_season = tf.concat([UNIT_POWER[power], INT_SEASON[season]], axis=0)
            power_season = tf.expand_dims(power_season, axis=0)
            power_season_with_dummies = tf.tile(power_season, [num_dummies, 1])
            probs, position_list = self.call(board_state_with_dummies,
                                             prev_orders_state_with_dummies,
                                             power_season_with_dummies,
                                             [year for _ in range(num_dummies)],
                                             [board_dict for _ in range(num_dummies)],
                                             power)
            # Drop the dummy batch entries and take the argmax order per position.
            prob_no_dummies = tf.squeeze(probs, axis=1)[:, 0, :]
            order_ix = tf.argmax(prob_no_dummies, axis=1)
            orders_list = [INVERSE_ORDER_DICT[index] for index in order_ix.numpy()]
            orders_probs_list = [prob_no_dummies[i][index] for i, index in enumerate(order_ix.numpy())]
            orders.append(orders_list)
            order_probs.append(orders_probs_list)
        return orders, order_probs