From 9fac1dacf1043705155f2d4bb18929ab1115fc04 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Jul 2020 05:41:56 +0200 Subject: [PATCH 01/75] I'm seriously not gonna split this TL;DR: * Remove PDO, move DQN up in the structure * railway-gen: Add cli, remove multiprocessing (it broke the generation, possibly due to duplicated seeds) * observation-utils: PyCharm-Reformat, add support for datastructrure used in eval data * railwai-utils: add generator class as generic handler, to skip first elements more efficiently and loop if needed * train: add support for different models * tree-observation: pycharm reformat + minor fixes --- railroads/README.md | 1 - src/dqn/__init__.py | 0 src/dqn/agent.py | 117 ----------------------- src/dqn/model.py | 28 ------ src/generate_railways.py | 46 +++++---- src/observation_utils.py | 50 ++++++---- src/ppo/__init__.py | 0 src/ppo/agent.py | 103 -------------------- src/ppo/model.py | 20 ---- src/railway_utils.py | 74 ++++++++++---- src/train.py | 201 +++++++++++++++++++-------------------- src/tree_observation.py | 114 ++++++++++++---------- 12 files changed, 275 insertions(+), 479 deletions(-) delete mode 100644 railroads/README.md delete mode 100644 src/dqn/__init__.py delete mode 100644 src/dqn/agent.py delete mode 100644 src/dqn/model.py delete mode 100644 src/ppo/__init__.py delete mode 100644 src/ppo/agent.py delete mode 100644 src/ppo/model.py diff --git a/railroads/README.md b/railroads/README.md deleted file mode 100644 index 5baf854..0000000 --- a/railroads/README.md +++ /dev/null @@ -1 +0,0 @@ -Generated railroads will be saved here diff --git a/src/dqn/__init__.py b/src/dqn/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/dqn/agent.py b/src/dqn/agent.py deleted file mode 100644 index 6e1e3f9..0000000 --- a/src/dqn/agent.py +++ /dev/null @@ -1,117 +0,0 @@ -import copy -import random -import pickle -import torch -import torch.nn.functional as F - -from dqn.model import QNetwork -from replay_memory import ReplayBuffer - -BUFFER_SIZE = 500_000 -BATCH_SIZE = 512 -GAMMA = 0.998 -TAU = 1e-3 -LR = 0.5e-4 -UPDATE_EVERY = 40 - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - -class Agent: - def __init__(self, state_size, action_size, num_agents, double_dqn=True): - self.action_size = action_size - self.double_dqn = double_dqn - - # Q-Network - self.qnetwork_local = QNetwork(state_size, action_size).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) - self.optimizer = torch.optim.Adam(self.qnetwork_local.parameters(), lr=LR) - - # Replay memory - self.memory = ReplayBuffer(BUFFER_SIZE) - self.num_agents = num_agents - self.t_step = 0 - - def reset(self): - self.finished = [False] * self.num_agents - - - # Decide on an action to take in the environment - - def act(self, state, eps=0.): - state = torch.from_numpy(state).float().unsqueeze(0).to(device) - self.qnetwork_local.eval() - with torch.no_grad(): - action_values = self.qnetwork_local(state) - - # Epsilon-greedy action selection - if random.random() > eps: - return torch.argmax(action_values).item() - else: return torch.randint(self.action_size, ()).item() - - - # Record the results of the agent's action and update the model - - def step(self, handle, state, action, next_state, agent_done, episode_done, collision): - if not self.finished[handle]: - if agent_done: - reward = 1 - elif collision: - reward = -5 - else: reward = -.1 - - # Save experience in replay memory - self.memory.push(state, action, reward, next_state, agent_done or episode_done) - self.finished[handle] = agent_done or episode_done - - # Perform a gradient update every UPDATE_EVERY time steps - self.t_step = (self.t_step + 1) % UPDATE_EVERY - if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 20: - self.learn(*self.memory.sample(BATCH_SIZE, device)) - - - def learn(self, states, actions, rewards, next_states, dones): - self.qnetwork_local.train() - - # Get expected Q values from local model - Q_expected = self.qnetwork_local(states).gather(1, actions) - - if self.double_dqn: - Q_best_action = self.qnetwork_local(next_states).argmax(1) - Q_targets_next = self.qnetwork_target(next_states).gather(1, Q_best_action.unsqueeze(-1)) - else: Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) - - # Compute Q targets for current states - Q_targets = rewards + GAMMA * Q_targets_next * (1 - dones) - - # Compute loss and perform a gradient step - self.optimizer.zero_grad() - loss = F.mse_loss(Q_expected, Q_targets) - loss.backward() - self.optimizer.step() - - # Update the target network parameters to `tau * local.parameters() + (1 - tau) * target.parameters()` - for target_param, local_param in zip(self.qnetwork_target.parameters(), self.qnetwork_local.parameters()): - target_param.data.copy_(TAU * local_param.data + (1.0 - TAU) * target_param.data) - - - # Checkpointing methods - - def save(self, path, *data): - torch.save(self.qnetwork_local.state_dict(), path / 'dqn/model_checkpoint.local') - torch.save(self.qnetwork_target.state_dict(), path / 'dqn/model_checkpoint.target') - torch.save(self.optimizer.state_dict(), path / 'dqn/model_checkpoint.optimizer') - with open(path / 'dqn/model_checkpoint.meta', 'wb') as file: - pickle.dump(data, file) - - def load(self, path, *defaults): - try: - print("Loading model from checkpoint...") - self.qnetwork_local.load_state_dict(torch.load(path / 'dqn/model_checkpoint.local')) - self.qnetwork_target.load_state_dict(torch.load(path / 'dqn/model_checkpoint.target')) - self.optimizer.load_state_dict(torch.load(path / 'dqn/model_checkpoint.optimizer')) - with open(path / 'dqn/model_checkpoint.meta', 'rb') as file: - return pickle.load(file) - except: - print("No checkpoint file was found") - return defaults diff --git a/src/dqn/model.py b/src/dqn/model.py deleted file mode 100644 index f64d641..0000000 --- a/src/dqn/model.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F - - -class QNetwork(nn.Module): - def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128): - super(QNetwork, self).__init__() - - self.fc1_val = nn.Linear(state_size, hidsize1) - self.fc2_val = nn.Linear(hidsize1, hidsize2) - self.fc3_val = nn.Linear(hidsize2, 1) - - self.fc1_adv = nn.Linear(state_size, hidsize1) - self.fc2_adv = nn.Linear(hidsize1, hidsize2) - self.fc3_adv = nn.Linear(hidsize2, action_size) - - def forward(self, x): - x = x.view(x.shape[0], -1) - - val = F.relu(self.fc1_val(x)) - val = F.relu(self.fc2_val(val)) - val = self.fc3_val(val) - - adv = F.relu(self.fc1_adv(x)) - adv = F.relu(self.fc2_adv(adv)) - adv = self.fc3_adv(adv) - - return val + adv - adv.mean() diff --git a/src/generate_railways.py b/src/generate_railways.py index 0f0a6ba..2b05355 100755 --- a/src/generate_railways.py +++ b/src/generate_railways.py @@ -1,21 +1,28 @@ -import time +import argparse +import multiprocessing import pickle -import numpy as np -from tqdm import tqdm from pathlib import Path -from flatland.envs.rail_generators import sparse_rail_generator, complex_rail_generator -from flatland.envs.schedule_generators import sparse_schedule_generator, complex_schedule_generator - -from railway_utils import create_random_railways +import numpy as np +from tqdm import tqdm +try: + from .railway_utils import create_random_railways +except: + from railway_utils import create_random_railways project_root = Path(__file__).resolve().parent.parent -rail_generator, schedule_generator = create_random_railways(project_root) +parser = argparse.ArgumentParser(description="Train an agent in the flatland environment") + +parser.add_argument("--agents", type=int, default=3, help="Number of episodes to train for") +parser.add_argument("--cities", type=int, default=3, help="Number of episodes to train for") +parser.add_argument("--width", type=int, default=35, help="Decay factor for epsilon-greedy exploration") -width, height = 50, 50 -n_agents = 5 +flags = parser.parse_args() +width = height = flags.width +n_agents = flags.agents +rail_generator, schedule_generator = create_random_railways(project_root, flags.cities) # Load in any existing railways for this map size so we don't overwrite them try: @@ -28,18 +35,25 @@ rail_networks, schedules = [], [] -# Generate 10000 random railways in 100 batches of 100 -for _ in range(100): - for i in tqdm(range(100), ncols=120, leave=False): +def do(schedules: list, rail_networks: list): + for _ in range(100): map, info = rail_generator(width, height, n_agents, num_resets=0, np_random=np.random) schedule = schedule_generator(map, n_agents, info['agents_hints'], num_resets=0, np_random=np.random) rail_networks.append((map, info)) schedules.append(schedule) + return - print(f"Saving {len(rail_networks)} railways") + +manager = multiprocessing.Manager() +shared_schedules = manager.list(schedules) +shared_rail_networks = manager.list(rail_networks) +# Generate 10000 random railways in 100 batches of 100 +for _ in tqdm(range(100), ncols=120, leave=False): + do(schedules, rail_networks) with open(project_root / f'railroads/rail_networks_{n_agents}x{width}x{height}.pkl', 'wb') as file: - pickle.dump(rail_networks, file) + pickle.dump(schedules, file, protocol=4) with open(project_root / f'railroads/schedules_{n_agents}x{width}x{height}.pkl', 'wb') as file: - pickle.dump(schedules, file) + pickle.dump(rail_networks, file, protocol=4) +print(f"Saved {len(shared_rail_networks)} railways") print("Done") diff --git a/src/observation_utils.py b/src/observation_utils.py index fc7e18c..5999d73 100644 --- a/src/observation_utils.py +++ b/src/observation_utils.py @@ -1,24 +1,28 @@ import numpy as np -from tree_observation import ACTIONS -ZERO_NODE = np.array([0] * 11) # For Q-Networks -INF_DISTANCE_NODE = np.array([0] * 6 + [np.inf] + [0] * 4) # For policy networks +try: + from .tree_observation import ACTIONS +except: + from tree_observation import ACTIONS + +ZERO_NODE = np.array([0] * 11) # For Q-Networks +INF_DISTANCE_NODE = np.array([0] * 6 + [np.inf] + [0] * 4) # For policy networks # Helper function to detect collisions def is_collision(obs): return obs is not None \ - and isinstance(obs.childs['L'], float) \ - and isinstance(obs.childs['R'], float) \ - and obs.childs['F'].num_agents_opposite_direction > 0 \ - and obs.childs['F'].dist_other_agent_encountered <= 1 \ - and obs.childs['F'].dist_other_agent_encountered < obs.childs['F'].dist_unusable_switch - # and obs.childs['F'].dist_other_agent_encountered < obs.childs['F'].dist_to_next_branch + and isinstance(obs.childs['L'], float) \ + and isinstance(obs.childs['R'], float) \ + and obs.childs['F'].num_agents_opposite_direction > 0 \ + and obs.childs['F'].dist_other_agent_encountered <= 1 \ + and obs.childs['F'].dist_other_agent_encountered < obs.childs['F'].dist_unusable_switch + # and obs.childs['F'].dist_other_agent_encountered < obs.childs['F'].dist_to_next_branch # Recursively create numpy arrays for each tree node def create_tree_features(node, current_depth, max_depth, empty_node, data): - if node == -np.inf: + if node == -np.inf or node is None: num_remaining_nodes = (4 ** (max_depth - current_depth + 1) - 1) // (4 - 1) data.extend([empty_node] * num_remaining_nodes) @@ -30,28 +34,34 @@ def create_tree_features(node, current_depth, max_depth, empty_node, data): return data + # Normalize an observation to [0, 1] and then clip it to get rid of any infinite-valued features def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False): if fixed_radius > 0: - max_obs = fixed_radius - else: max_obs = np.max(obs[np.where(obs < 1000)], initial=1) + 1 + max_obs = fixed_radius + else: + max_obs = np.max(obs[np.where(obs < 1000)], initial=1) + 1 min_obs = np.min(obs[np.where(obs >= 0)], initial=max_obs) if normalize_to_range else 0 if max_obs == min_obs: - return np.clip(obs / max_obs, clip_min, clip_max) - else: return np.clip((obs - min_obs) / np.abs(max_obs - min_obs), clip_min, clip_max) + return np.clip(obs / max_obs, clip_min, clip_max) + else: + return np.clip((obs - min_obs) / np.abs(max_obs - min_obs), clip_min, clip_max) # Normalize a tree observation def normalize_observation(tree, max_depth, zero_center=True): empty_node = ZERO_NODE if zero_center else INF_DISTANCE_NODE - data = np.concatenate(create_tree_features(tree, 0, max_depth, empty_node, [])).reshape((-1, 11)) + data = np.concatenate([create_tree_features(t, 0, max_depth, empty_node, []) for t in tree.values()] + if isinstance(tree, dict) else + create_tree_features(tree, 0, max_depth, empty_node, [])).reshape((-1, 11)) - obs_data = norm_obs_clip(data[:,:6].flatten()) - distances = norm_obs_clip(data[:,6], normalize_to_range=True) - agent_data = np.clip(data[:,7:].flatten(), -1, 1) + obs_data = norm_obs_clip(data[:, :6].flatten()) + distances = norm_obs_clip(data[:, 6], normalize_to_range=True) + agent_data = np.clip(data[:, 7:].flatten(), -1, 1) if zero_center: - return np.concatenate((obs_data - obs_data.mean(), distances, agent_data - agent_data.mean())) - else: return np.concatenate((obs_data, distances, agent_data)) + return np.concatenate((obs_data - obs_data.mean(), distances, agent_data - agent_data.mean())) + else: + return np.concatenate((obs_data, distances, agent_data)) diff --git a/src/ppo/__init__.py b/src/ppo/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/ppo/agent.py b/src/ppo/agent.py deleted file mode 100644 index 3f7071c..0000000 --- a/src/ppo/agent.py +++ /dev/null @@ -1,103 +0,0 @@ -import pickle -import random -import numpy as np -import torch -from torch.distributions.categorical import Categorical - -from ppo.model import PolicyNetwork -from replay_memory import Episode, ReplayBuffer - -BUFFER_SIZE = 32_000 -BATCH_SIZE = 4096 -GAMMA = 0.8 -LR = 0.5e-4 -CLIP_FACTOR = .005 -UPDATE_EVERY = 120 - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - -class Agent: - def __init__(self, state_size, action_size, num_agents): - self.policy = PolicyNetwork(state_size, action_size).to(device) - self.old_policy = PolicyNetwork(state_size, action_size).to(device) - self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR) - - self.episodes = [Episode() for _ in range(num_agents)] - self.memory = ReplayBuffer(BUFFER_SIZE) - self.t_step = 0 - - def reset(self): - self.finished = [False] * len(self.episodes) - - - # Decide on an action to take in the environment - - def act(self, state, eps=None): - self.policy.eval() - with torch.no_grad(): - output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device)) - return Categorical(output).sample().item() - - - # Record the results of the agent's action and update the model - - def step(self, handle, state, action, next_state, agent_done, episode_done, collision): - if not self.finished[handle]: - if agent_done: - reward = 1 - elif collision: - reward = -.5 - else: reward = 0 - - # Push experience into Episode memory - self.episodes[handle].push(state, action, reward, next_state, agent_done or episode_done) - - # When we finish the episode, discount rewards and push the experience into replay memory - if agent_done or episode_done: - self.episodes[handle].discount_rewards(GAMMA) - self.memory.push_episode(self.episodes[handle]) - self.episodes[handle].reset() - self.finished[handle] = True - - # Perform a gradient update every UPDATE_EVERY time steps - self.t_step = (self.t_step + 1) % UPDATE_EVERY - if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4: - self.learn(*self.memory.sample(BATCH_SIZE, device)) - - def learn(self, states, actions, rewards, next_state, done): - self.policy.train() - - responsible_outputs = torch.gather(self.policy(states), 1, actions) - old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach() - - # rewards = rewards - rewards.mean() - ratio = responsible_outputs / (old_responsible_outputs + 1e-5) - clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) - loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean() - - # Compute loss and perform a gradient step - self.old_policy.load_state_dict(self.policy.state_dict()) - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - - # Checkpointing methods - - def save(self, path, *data): - torch.save(self.policy.state_dict(), path / 'ppo/model_checkpoint.policy') - torch.save(self.optimizer.state_dict(), path / 'ppo/model_checkpoint.optimizer') - with open(path / 'ppo/model_checkpoint.meta', 'wb') as file: - pickle.dump(data, file) - - def load(self, path, *defaults): - try: - print("Loading model from checkpoint...") - self.policy.load_state_dict(torch.load(path / 'ppo/model_checkpoint.policy')) - self.optimizer.load_state_dict(torch.load(path / 'ppo/model_checkpoint.optimizer')) - with open(path / 'ppo/model_checkpoint.meta', 'rb') as file: - return pickle.load(file) - except: - print("No checkpoint file was found") - return defaults diff --git a/src/ppo/model.py b/src/ppo/model.py deleted file mode 100644 index 51febc2..0000000 --- a/src/ppo/model.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F - - -class PolicyNetwork(nn.Module): - def __init__(self, state_size, action_size, hidsize1=128, hidsize2=64, hidsize3=32): - super().__init__() - self.fc1 = nn.Linear(state_size, hidsize1) - self.fc2 = nn.Linear(hidsize1, hidsize2) - # self.fc3 = nn.Linear(hidsize2, hidsize3) - self.output = nn.Linear(hidsize2, action_size) - self.softmax = nn.Softmax(dim=1) - self.bn0 = nn.BatchNorm1d(state_size, affine=False) - - def forward(self, inputs): - x = self.bn0(inputs.float()) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - # x = F.relu(self.fc3(x)) - return self.softmax(self.output(x)) diff --git a/src/railway_utils.py b/src/railway_utils.py index 6362b9d..b17c025 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -1,30 +1,64 @@ +import os import pickle from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator +class Generator: + def __init__(self, path, start_index=0): + self.path = path + self.index = start_index + self.data = iter([]) + self.len = 0 + + self._load() + + def _load(self): + with open(self.path, 'rb') as file: + data = pickle.load(file) + self.len = len(data) + self.data = iter(data[self.index:]) + + def __next__(self): + try: + data = next(self.data) + except StopIteration: + self._load() + if self.index >= self.len: + print("[WARNING] Restarting training loop from zero") + self.index = 0 + self._load() + data = next(self) + self.index += 1 + return data + + def __len__(self): + return self.len + + def __call__(self, *args, **kwargs): + return next(self) + + # Helper function to load in precomputed railway networks -def load_precomputed_railways(project_root, flags): - with open(project_root / f'railroads/rail_networks_{flags.num_agents}x{flags.grid_width}x{flags.grid_height}.pkl', 'rb') as file: - data = pickle.load(file) - rail_networks = iter(data) - print(f"Loading {len(data)} railways...") - with open(project_root / f'railroads/schedules_{flags.num_agents}x{flags.grid_width}x{flags.grid_height}.pkl', 'rb') as file: - schedules = iter(pickle.load(file)) - - rail_generator = lambda *args: next(rail_networks) - schedule_generator = lambda *args: next(schedules) - return rail_generator, schedule_generator +def load_precomputed_railways(project_root, start_index): + prefix = os.path.join(project_root, 'railroads') + suffix = f'_sum.pkl' + rail = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) + sched = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) + print(f"Working on {len(rail)} tracks") + return rail, sched + # Helper function to generate railways on the fly -def create_random_railways(project_root): - speed_ration_map = { - 1 / 1: 1.0, # Fast passenger train - 1 / 2.: 0.0, # Fast freight train - 1 / 3.: 0.0, # Slow commuter train - 1 / 4.: 0.0 } # Slow freight train - - rail_generator = sparse_rail_generator(grid_mode=False, max_num_cities=3, max_rails_between_cities=2, max_rails_in_city=3) - schedule_generator = sparse_schedule_generator(speed_ration_map) +def create_random_railways(project_root, max_cities=5): + speed_ratio_map = { + 1 / 1: 1.0, # Fast passenger train + 1 / 2.: 0.0, # Fast freight train + 1 / 3.: 0.0, # Slow commuter train + 1 / 4.: 0.0} # Slow freight train + + rail_generator = sparse_rail_generator(grid_mode=False, max_num_cities=max_cities, + max_rails_between_cities=max_cities - 1, max_rails_in_city=max_cities - 1) + schedule_generator = sparse_schedule_generator(speed_ratio_map) return rail_generator, schedule_generator diff --git a/src/train.py b/src/train.py index d52112b..2ffc0f4 100644 --- a/src/train.py +++ b/src/train.py @@ -1,21 +1,25 @@ -import cv2 -import time import argparse -import numpy as np +import copy +import time from pathlib import Path -from collections import deque -from tensorboardX import SummaryWriter -from flatland.envs.rail_env import RailEnv, RailEnvActions +import cv2 +import numpy as np from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters +from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool, AgentRenderVariant +from tensorboardX import SummaryWriter -from dqn.agent import Agent as DQN_Agent -from ppo.agent import Agent as PPO_Agent -from tree_observation import TreeObservation -from observation_utils import normalize_observation, is_collision -from railway_utils import load_precomputed_railways, create_random_railways - +try: + from .agent import Agent as DQN_Agent + from .tree_observation import TreeObservation + from .observation_utils import normalize_observation, is_collision + from .railway_utils import load_precomputed_railways, create_random_railways +except: + from agent import Agent as DQN_Agent + from tree_observation import TreeObservation + from observation_utils import normalize_observation, is_collision + from railway_utils import load_precomputed_railways, create_random_railways project_root = Path(__file__).resolve().parent.parent parser = argparse.ArgumentParser(description="Train an agent in the flatland environment") @@ -23,8 +27,10 @@ # Task parameters parser.add_argument("--train", type=boolean, default=True, help="Whether to train the model or just evaluate it") -parser.add_argument("--load-model", default=False, action='store_true', help="Whether to load the model from the last checkpoint") -parser.add_argument("--load-railways", type=boolean, default=True, help="Whether to load in pre-generated railway networks") +parser.add_argument("--load-model", default=False, action='store_true', + help="Whether to load the model from the last checkpoint") +parser.add_argument("--load-railways", type=boolean, default=True, + help="Whether to load in pre-generated railway networks") parser.add_argument("--report-interval", type=int, default=100, help="Iterations between reports") parser.add_argument("--render-interval", type=int, default=0, help="Iterations between renders") @@ -33,25 +39,38 @@ parser.add_argument("--grid-height", type=int, default=50, help="Number of rows in the environment grid") parser.add_argument("--num-agents", type=int, default=5, help="Number of agents in each episode") parser.add_argument("--tree-depth", type=int, default=1, help="Depth of the observation tree") +parser.add_argument("--model-depth", type=int, default=4, help="Depth of the observation tree") +parser.add_argument("--hidden-factor", type=int, default=15, help="Depth of the observation tree") # Training parameters -parser.add_argument("--agent-type", default="ppo", choices=["dqn", "ppo"], help="Which type of RL agent to use") -parser.add_argument("--num-episodes", type=int, default=10000, help="Number of episodes to train for") -parser.add_argument("--epsilon-decay", type=float, default=0.997, help="Decay factor for epsilon-greedy exploration") +parser.add_argument("--agent-type", default="dqn", choices=["dqn", "ppo"], help="Which type of RL agent to use") +parser.add_argument("--num-episodes", type=int, default=10**6, help="Number of episodes to train for") +parser.add_argument("--epsilon-decay", type=float, default=0.999, help="Decay factor for epsilon-greedy exploration") +parser.add_argument("--step-reward", type=float, default=-1e-2, help="Depth of the observation tree") flags = parser.parse_args() - # Seeded RNG so we can replicate our results np.random.seed(1) # Create a tensorboard SummaryWriter summary = SummaryWriter(f'tensorboard/dqn/agents: {flags.num_agents}, tree_depth: {flags.tree_depth}') - +# Calculate the state size based on the number of nodes in the tree observation +num_features_per_node = 11 # env.obs_builder.observation_dim +num_nodes = sum(np.power(4, i) for i in range(flags.tree_depth + 1)) +state_size = num_nodes * num_features_per_node +action_size = 5 +# Load an RL agent and initialize it from checkpoint if necessary +agent = DQN_Agent(state_size, action_size, flags.num_agents, flags.model_depth, flags.hidden_factor) +if flags.load_model: + start, eps = agent.load(project_root / 'checkpoints', 0, 1.0) +else: + start, eps = 0, 1.0 # We need to either load in some pre-generated railways from disk, or else create a random railway generator. if flags.load_railways: - rail_generator, schedule_generator = load_precomputed_railways(project_root, flags) -else: rail_generator, schedule_generator = create_random_railways(project_root) + rail_generator, schedule_generator = load_precomputed_railways(project_root, start) +else: + rail_generator, schedule_generator = create_random_railways(project_root) # Create the Flatland environment env = RailEnv(width=flags.grid_width, height=flags.grid_height, number_of_agents=flags.num_agents, @@ -59,59 +78,37 @@ schedule_generator=schedule_generator, malfunction_generator_and_process_data=malfunction_from_params(MalfunctionParameters(1 / 8000, 15, 50)), obs_builder_object=TreeObservation(max_depth=flags.tree_depth) -) + ) # After training we want to render the results so we also load a renderer -env_renderer = RenderTool(env, gl="PILSVG", screen_width=800, screen_height=800, agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX) - -# Calculate the state size based on the number of nodes in the tree observation -num_features_per_node = env.obs_builder.observation_dim -num_nodes = sum(np.power(4, i) for i in range(flags.tree_depth + 1)) -state_size = num_nodes * num_features_per_node -action_size = 5 +env_renderer = RenderTool(env, gl="PILSVG", screen_width=800, screen_height=800, + agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX) # Add some variables to keep track of the progress -scores_window, steps_window, collisions_window, done_window = [deque(maxlen=200) for _ in range(4)] -agent_obs = [None] * flags.num_agents -agent_obs_buffer = [None] * flags.num_agents -agent_action_buffer = [2] * flags.num_agents -max_steps = 8 * (flags.grid_width + flags.grid_height) -start_time = time.time() - -# Load an RL agent and initialize it from checkpoint if necessary -if flags.agent_type == "dqn": - agent = DQN_Agent(state_size, action_size, flags.num_agents) -elif flags.agent_type == "ppo": - agent = PPO_Agent(state_size, action_size, flags.num_agents) +current_score = current_steps = current_collisions = current_done = mean_score = mean_steps = mean_collisions = mean_done = 0 -if flags.load_model: - start, eps = agent.load(project_root / 'checkpoints', 0, 1.0) -else: start, eps = 0, 1.0 +agent_action_buffer = [] +start_time = time.time() if not flags.train: eps = 0.0 -# We don't want to retrain on old railway networks when we restart from a checkpoint, so we just loop -# through the generators to get all the old networks out of the way -if start > 0: print(f"Skipping {start} railways") -for _ in range(0, start): - rail_generator() - schedule_generator() - - # Helper function to detect collisions -ACTIONS = { 0: 'B', 1: 'L', 2: 'F', 3: 'R', 4: 'S' } +ACTIONS = {0: 'B', 1: 'L', 2: 'F', 3: 'R', 4: 'S'} + def is_collision(a): if obs[a] is None: return False is_junction = not isinstance(obs[a].childs['L'], float) or not isinstance(obs[a].childs['R'], float) if not is_junction or env.agents[a].speed_data['position_fraction'] > 0: - action = ACTIONS[env.agents[a].speed_data['transition_action_on_cellexit']] if is_junction else 'F' - return obs[a].childs[action].num_agents_opposite_direction > 0 \ - and obs[a].childs[action].dist_other_agent_encountered <= 1 \ - and obs[a].childs[action].dist_other_agent_encountered < obs[a].childs[action].dist_unusable_switch - else: return False + action = ACTIONS[env.agents[a].speed_data['transition_action_on_cellexit']] if is_junction else 'F' + return obs[a].childs[action].num_agents_opposite_direction > 0 \ + and obs[a].childs[action].dist_other_agent_encountered <= 1 \ + and obs[a].childs[action].dist_other_agent_encountered < obs[a].childs[action].dist_unusable_switch + else: + return False + # Helper function to render the environment def render(): @@ -119,21 +116,13 @@ def render(): cv2.imshow('Render', cv2.cvtColor(env_renderer.get_image(), cv2.COLOR_BGR2RGB)) cv2.waitKey(120) -# Helper function to generate a report -def get_report(show_time=False): - training = 'Training' if flags.train else 'Evaluating' - return ' | '.join(filter(None, [ - f'\r{training} {flags.num_agents} Agents on {flags.grid_width} x {flags.grid_height} Map', - f'Episode {episode:<5}', - f'Average Score: {np.mean(scores_window):.3f}', - f'Average Steps Taken: {np.mean(steps_window):<6.1f}', - f'Collisions: {100 * np.mean(collisions_window):>5.2f}%', - f'Finished: {100 * np.mean(done_window):>6.2f}%', - f'Epsilon: {eps:.2f}' if flags.agent_type == "dqn" else None, - f'Time taken: {time.time() - start_time:.2f}s' if show_time else None])) + ' ' +def get_means(x, y, c, s): + return (x * 3 + c) / 4, (y * (s - 1) + c) / s +episode = 0 + # Main training loop for episode in range(start + 1, flags.num_episodes + 1): agent.reset() @@ -141,27 +130,30 @@ def get_report(show_time=False): obs, info = env.reset(True, True) score, steps_taken, collision = 0, 0, False - # Build initial observations for each agent - for a in range(flags.num_agents): - agent_obs[a] = normalize_observation(obs[a], flags.tree_depth, zero_center=flags.agent_type == 'dqn') - agent_obs_buffer[a] = agent_obs[a].copy() + agent_obs = [normalize_observation(obs[a], flags.tree_depth, zero_center=True) + for a in obs.keys()] + agent_obs_buffer = copy.deepcopy(agent_obs) + agent_count = len(agent_obs) + agent_action_buffer = [2] * agent_count # Run an episode + max_steps = 8 * (env.width + env.height) for step in range(max_steps): - update_values = [False] * flags.num_agents + update_values = [False] * agent_count action_dict = {} - for a in range(flags.num_agents): + for a in range(agent_count): if info['action_required'][a]: - action_dict[a] = agent.act(agent_obs[a], eps=eps) - # action_dict[a] = np.random.randint(5) - update_values[a] = True - steps_taken += 1 - else: action_dict[a] = 0 + action_dict[a] = agent.act(agent_obs[a], eps=eps) + # action_dict[a] = np.random.randint(5) + update_values[a] = True + steps_taken += 1 + else: + action_dict[a] = 0 # Environment step obs, rewards, done, info = env.step(action_dict) - score += sum(rewards.values()) / flags.num_agents + score += sum(rewards.values()) / agent_count # Check for collisions and episode completion if step == max_steps - 1: @@ -171,14 +163,15 @@ def get_report(show_time=False): # done['__all__'] = True # Update replay buffer and train agent - for a in range(flags.num_agents): + for a in range(agent_count): if flags.train and (update_values[a] or done[a] or done['__all__']): - agent.step(a, agent_obs_buffer[a], agent_action_buffer[a], agent_obs[a], done[a], done['__all__'], is_collision(a)) + agent.step(a, agent_obs_buffer[a], agent_action_buffer[a], agent_obs[a], done[a], done['__all__'], + is_collision(a), flags.step_reward) agent_obs_buffer[a] = agent_obs[a].copy() agent_action_buffer[a] = action_dict[a] if obs[a]: - agent_obs[a] = normalize_observation(obs[a], flags.tree_depth, zero_center=flags.agent_type == 'dqn') + agent_obs[a] = normalize_observation(obs[a], flags.tree_depth, zero_center=True) # Render # if flags.render_interval and episode % flags.render_interval == 0: @@ -189,24 +182,26 @@ def get_report(show_time=False): if done['__all__']: break # Epsilon decay - if flags.train: eps = max(0.01, flags.epsilon_decay * eps) + if flags.train: + eps = max(0.01, flags.epsilon_decay * eps) # Save some training statistics in their respective deques - tasks_finished = sum(done[i] for i in range(flags.num_agents)) - done_window.append(tasks_finished / max(1, flags.num_agents)) - collisions_window.append(1. if collision else 0.) - scores_window.append(score / max_steps) - steps_window.append(steps_taken) - - # Generate training reports, saving our progress every so often - print(get_report(), end=" ") + tasks_finished = sum(done[i] for i in range(agent_count)) + current_done, mean_done = get_means(current_done, mean_done, tasks_finished / max(1, agent_count), episode) + current_collisions, mean_collisions = get_means(current_collisions, mean_collisions, int(collision), episode) + current_score, mean_score = get_means(current_score, mean_score, score / max_steps, episode) + current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken, episode) + + print(f'\rEpisode {episode:<5}' + f' | Score: {current_score:.4f}, {mean_score:.4f}' + f' | Steps: {current_steps:6.1f}, {mean_steps:6.1f}' + f' | Collisions: {100 * current_collisions:5.2f}%, {100 * mean_collisions:5.2f}%' + f' | Finished: {100 * current_done:6.2f}%, {100 * mean_done:6.2f}%' + f' | Epsilon: {eps:.2f}' + f' | Episode/s: {episode / (time.time() - start_time):.2f}s', end='') + if episode % flags.report_interval == 0: - print(get_report(show_time=True)) - start_time = time.time() - if flags.train: agent.save(project_root / 'checkpoints', episode, eps) - - # Add stats to the tensorboard summary - summary.add_scalar('performance/avg_score', np.mean(scores_window), episode) - summary.add_scalar('performance/avg_steps', np.mean(steps_window), episode) - summary.add_scalar('performance/completions', np.mean(done_window), episode) - summary.add_scalar('performance/collisions', np.mean(collisions_window), episode) + print("") + if flags.train: + agent.save(project_root / 'checkpoints', episode, eps) + # Add stats to the tensorboard summary diff --git a/src/tree_observation.py b/src/tree_observation.py index 91c6e9e..d126eb7 100644 --- a/src/tree_observation.py +++ b/src/tree_observation.py @@ -1,27 +1,30 @@ -import numpy as np from collections import defaultdict +import numpy as np from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.observations import TreeObsForRailEnv - ACTIONS = ['L', 'F', 'R', 'B'] Node = TreeObsForRailEnv.Node + def first(list): return next(iter(list)) + def get_action(orientation, direction): return ACTIONS[(direction - orientation + 1 + 4) % 4] + def get_direction(orientation, action): if action == 1: - return (orientation + 4 - 1) % 4 + return (orientation + 4 - 1) % 4 elif action == 3: - return (orientation + 1) % 4 - else: return orientation + return (orientation + 1) % 4 + else: + return orientation class RailNode: @@ -35,39 +38,42 @@ def __repr__(self): return f'RailNode({self.position}, {len(self.edges)})' - class TreeObservation(ObservationBuilder): def __init__(self, max_depth): super().__init__() self.max_depth = max_depth self.observation_dim = 11 - # Create a graph representation of the current rail network def reset(self): - self.target_positions = { agent.target: 1 for agent in self.env.agents } - self.edge_positions = defaultdict(list) # (cell.position, direction) -> [(start, end, direction, distance)] - self.edge_paths = defaultdict(list) # (node.position, direction) -> [(cell.position, direction)] + self.target_positions = {agent.target: 1 for agent in self.env.agents} + self.edge_positions = defaultdict(list) # (cell.position, direction) -> [(start, end, direction, distance)] + self.edge_paths = defaultdict(list) # (node.position, direction) -> [(cell.position, direction)] # First, we find a node by starting at one of the agents and following the rails until we reach a junction agent = first(self.env.agents) position = agent.initial_position direction = agent.direction - while not self.is_junction(position) and not self.is_target(position): + while True: + try: + out = self.is_junction(position) or self.is_target(position) + except IndexError: + break + if not out: + break direction = first(self.get_possible_transitions(position, direction)) position = get_new_position(position, direction) # Now we create a graph representation of the rail network, starting from this node transitions = self.get_all_transitions(position) - root_nodes = { t: RailNode(position, t, self.is_target(position)) for t in transitions if t } - self.graph = { (*position, d): root_nodes[t] for d, t in enumerate(transitions) if t } + root_nodes = {t: RailNode(position, t, self.is_target(position)) for t in transitions if t} + self.graph = {(*position, d): root_nodes[t] for d, t in enumerate(transitions) if t} for transitions, node in root_nodes.items(): for direction in transitions: self.explore_branch(node, get_new_position(position, direction), direction) - def explore_branch(self, node, position, direction): original_direction = direction edge_positions = {} @@ -83,9 +89,9 @@ def explore_branch(self, node, position, direction): # Create any nodes that aren't in the graph yet transitions = self.get_all_transitions(position) - nodes = { t: RailNode(position, t, self.is_target(position)) - for d, t in enumerate(transitions) - if t and (*position, d) not in self.graph } + nodes = {t: RailNode(position, t, self.is_target(position)) + for d, t in enumerate(transitions) + if t and (*position, d) not in self.graph} for d, t in enumerate(transitions): if t in nodes: @@ -103,14 +109,13 @@ def explore_branch(self, node, position, direction): for direction in transitions: self.explore_branch(node, get_new_position(position, direction), direction) - # Create a tree observation for each agent, based on the graph we created earlier - def get_many(self, handles = []): - self.nodes_with_agents_going, self.edges_with_agents_going = {}, defaultdict(dict) + def get_many(self, handles=[]): + self.nodes_with_agents_going, self.edges_with_agents_going = {}, defaultdict(dict) self.nodes_with_agents_coming, self.edges_with_agents_coming = {}, defaultdict(dict) - self.nodes_with_malfunctions, self.edges_with_malfunctions = {}, defaultdict(dict) - self.nodes_with_departures, self.edges_with_departures = {}, defaultdict(dict) + self.nodes_with_malfunctions, self.edges_with_malfunctions = {}, defaultdict(dict) + self.nodes_with_departures, self.edges_with_departures = {}, defaultdict(dict) # Create some lookup tables that we can use later to figure out how far away the agents are from each other. for agent in self.env.agents: @@ -119,7 +124,8 @@ def get_many(self, handles = []): if (*agent.initial_position, direction) in self.graph: self.nodes_with_departures[(*agent.initial_position, direction)] = 1 - for start, _, start_direction, distance in self.edge_positions[(*agent.initial_position, direction)]: + for start, _, start_direction, distance in self.edge_positions[ + (*agent.initial_position, direction)]: self.edges_with_departures[(*start.position, start_direction)][agent.handle] = distance if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and agent.position: @@ -150,15 +156,18 @@ def get_many(self, handles = []): coming_direction = (exit_direction + 2) % 4 edge_dict = self.edges_with_agents_coming if direction == coming_direction else self.edges_with_agents_going if direction == agent.direction or direction == coming_direction: - for start, _, start_direction, distance in self.edge_positions[(*agent.position, direction)]: - edge_distance = distance if direction == agent.direction else start.edges[start_direction][1] - distance - edge_dict[(*start.position, start_direction)][agent.handle] = (distance, agent.speed_data['speed']) - + for start, _, start_direction, distance in self.edge_positions[ + (*agent.position, direction)]: + edge_distance = distance if direction == agent.direction else \ + start.edges[start_direction][1] - distance + edge_dict[(*start.position, start_direction)][agent.handle] = ( + distance, agent.speed_data['speed']) # Check for malfunctions if agent.malfunction_data['malfunction']: if (*agent.position, direction) in self.graph: - self.nodes_with_malfunctions[(*agent.position, direction)] = agent.malfunction_data['malfunction'] + self.nodes_with_malfunctions[(*agent.position, direction)] = agent.malfunction_data[ + 'malfunction'] for start, _, start_direction, distance in self.edge_positions[(*agent.position, direction)]: self.edges_with_malfunctions[(*start.position, start_direction)][agent.handle] = \ @@ -166,39 +175,40 @@ def get_many(self, handles = []): return super().get_many(handles) - # Compute the observation for a single agent def get(self, handle): agent = self.env.agents[handle] visited_cells = set() if agent.status == RailAgentStatus.READY_TO_DEPART: - agent_position = agent.initial_position + agent_position = agent.initial_position elif agent.status == RailAgentStatus.ACTIVE: - agent_position = agent.position + agent_position = agent.position elif agent.status == RailAgentStatus.DONE: - agent_position = agent.target - else: return None + agent_position = agent.target + else: + return None # The root node contains information about the agent itself - children = { x: -np.inf for x in ACTIONS } + children = {x: -np.inf for x in ACTIONS} dist_min_to_target = self.env.distance_map.get()[(handle, *agent_position, agent.direction)] agent_malfunctioning, agent_speed = agent.malfunction_data['malfunction'], agent.speed_data['speed'] - root_tree_node = Node(0, 0, 0, 0, 0, 0, dist_min_to_target, 0, 0, agent_malfunctioning, agent_speed, 0, children) + root_tree_node = Node(0, 0, 0, 0, 0, 0, dist_min_to_target, 0, 0, agent_malfunctioning, agent_speed, 0, + children) # Then we build out the tree by exploring from this node key = (*agent_position, agent.direction) - if key in self.graph: # If we're sitting on a junction, branch out immediately + if key in self.graph: # If we're sitting on a junction, branch out immediately node = self.graph[key] - if len(node.edges) > 1: # Major node + if len(node.edges) > 1: # Major node for direction in self.graph[key].edges.keys(): root_tree_node.childs[get_action(agent.direction, direction)] = \ self.get_tree_branch(agent, node, direction, visited_cells, 0, 1) - else: # Minor node + else: # Minor node direction = first(self.get_possible_transitions(node.position, agent.direction)) root_tree_node.childs['F'] = self.get_tree_branch(agent, node, direction, visited_cells, 0, 1) - else: # Just create a single child in the forward direction + else: # Just create a single child in the forward direction prev_node, next_node, direction, distance = first(self.edge_positions[key]) root_tree_node.childs['F'] = self.get_tree_branch(agent, prev_node, direction, visited_cells, -distance, 1) @@ -206,7 +216,6 @@ def get(self, handle): return root_tree_node - # Get the next tree node, starting from `node`, facing `orientation`, and moving in `direction`. def get_tree_branch(self, agent, node, direction, visited_cells, total_distance, depth): visited_cells.add((*node.position, 0)) @@ -269,28 +278,32 @@ def get_tree_branch(self, agent, node, direction, visited_cells, total_distance, # Check for target nodes up ahead if next_node.is_target: if self.is_own_target(agent, next_node): - distance_to_own_target = min(distance_to_own_target, edge_length + distance) - else: distance_to_other_target = min(distance_to_other_target, edge_length + distance) + distance_to_own_target = min(distance_to_own_target, edge_length + distance) + else: + distance_to_other_target = min(distance_to_other_target, edge_length + distance) # Move on to the next node node = next_node edge_length += distance - if len(node.edges) == 1 and not self.is_own_target(agent, node): # This is a minor node, keep exploring - direction, (next_node, distance) = first(node.edges.items()) - if not node.is_target: - distance_to_minor_node = min(distance_to_minor_node, edge_length) - else: break + if len(node.edges) == 1 and not self.is_own_target(agent, node): # This is a minor node, keep exploring + direction, (next_node, distance) = first(node.edges.items()) + if not node.is_target: + distance_to_minor_node = min(distance_to_minor_node, edge_length) + else: + break # Create a new tree node and populate its children if depth < self.max_depth: - children = { x: -np.inf for x in ACTIONS } + children = {x: -np.inf for x in ACTIONS} if not self.is_own_target(agent, node): for direction in node.edges.keys(): children[get_action(orientation, direction)] = \ - self.get_tree_branch(agent, node, direction, visited_cells, total_distance + edge_length, depth + 1) + self.get_tree_branch(agent, node, direction, visited_cells, total_distance + edge_length, + depth + 1) - else: children = {} + else: + children = {} return Node(dist_own_target_encountered=total_distance + distance_to_own_target, dist_other_target_encountered=total_distance + distance_to_other_target, @@ -306,7 +319,6 @@ def get_tree_branch(self, agent, node, direction, visited_cells, total_distance, num_agents_ready_to_depart=num_agent_departures, childs=children) - # Helper functions def get_possible_transitions(self, position, direction): From 4870b925fbe7c35e67b89b35342c4bb54431e3ee Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Jul 2020 05:42:29 +0200 Subject: [PATCH 02/75] feat: add concatenation script for railways --- railroads2/cat.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 railroads2/cat.py diff --git a/railroads2/cat.py b/railroads2/cat.py new file mode 100644 index 0000000..abbf632 --- /dev/null +++ b/railroads2/cat.py @@ -0,0 +1,29 @@ +import os +import pickle +import random + + +def main(base, out_name): + files = sorted([i for i in os.listdir() if i.startswith(base) and not i.endswith('.bak') and 'sum' not in i]) + print(f'Concatenating {", ".join(files)}') + + out = [] + + for name in files: + with open(name, 'rb') as f: + try: + out.extend(pickle.load(f)) + except Exception as e: + print(f'Caught {e} while processing {name}') + + name = out_name+'sum.pkl' + random.seed(0) + random.shuffle(out) + with open(name, 'wb') as f: + pickle.dump(out, f) + print(f'Dumped {len(out)} items from {len(files)} sources to {name}') + +if __name__ == '__main__': + main('rail_networks_', 'schedules_') + main('schedules_', 'rail_networks_') + From 2c4e1c9fe7567dafee0dc58c84cfa0ae3a073a69 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Jul 2020 05:42:58 +0200 Subject: [PATCH 03/75] feat: add slightly improved dqn agent --- src/agent.py | 126 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 src/agent.py diff --git a/src/agent.py b/src/agent.py new file mode 100644 index 0000000..7d7acdc --- /dev/null +++ b/src/agent.py @@ -0,0 +1,126 @@ +import copy +import pickle +import random + +import torch +import torch.nn.functional as F +from torch_optimizer import Yogi as Optimizer + +try: + from .model import QNetwork + from .replay_memory import ReplayBuffer +except: + from model import QNetwork + from replay_memory import ReplayBuffer +import os + +BUFFER_SIZE = 500_000 +BATCH_SIZE = 512 +GAMMA = 0.998 +TAU = 1e-3 +LR = 3e-5 +UPDATE_EVERY = 200 + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class Agent: + def __init__(self, state_size, action_size, num_agents, model_depth, hidden_factor, double_dqn=True): + self.action_size = action_size + self.double_dqn = double_dqn + + # Q-Network + self.qnetwork_local = QNetwork(state_size, action_size, hidden_factor, model_depth).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + self.optimizer = Optimizer(self.qnetwork_local.parameters(), lr=LR, weight_decay=1e-2) + + # Replay memory + self.memory = ReplayBuffer(BUFFER_SIZE) + self.t_step = 0 + + def reset(self): + self.finished = [False] * 40 # Up to 40 agents used + + # Decide on an action to take in the environment + + def act(self, state, eps=0.): + state = torch.from_numpy(state).float().unsqueeze(0).to(device) + self.qnetwork_local.eval() + with torch.no_grad(): + action_values = self.qnetwork_local(state) + + # Epsilon-greedy action selection + if random.random() > eps: + return torch.argmax(action_values).item() + else: + return torch.randint(self.action_size, ()).item() + + # Record the results of the agent's action and update the model + + def step(self, handle, state, action, next_state, agent_done, episode_done, collision, step_reward=-1): + if not self.finished[handle]: + if agent_done: + reward = 1 + elif collision: + reward = -5 + else: + reward = step_reward + + # Save experience in replay memory + self.memory.push(state, action, reward, next_state, agent_done or episode_done) + self.finished[handle] = agent_done or episode_done + + # Perform a gradient update every UPDATE_EVERY time steps + self.t_step = (self.t_step + 1) % UPDATE_EVERY + if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 20: + self.learn(*self.memory.sample(BATCH_SIZE, device)) + + def learn(self, states, actions, rewards, next_states, dones): + self.qnetwork_local.train() + + # Get expected Q values from local model + Q_expected = self.qnetwork_local(states).gather(1, actions) + + if self.double_dqn: + Q_best_action = self.qnetwork_local(next_states).argmax(1) + Q_targets_next = self.qnetwork_target(next_states).gather(1, Q_best_action.unsqueeze(-1)) + else: + Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) + + # Compute Q targets for current states + Q_targets = rewards + GAMMA * Q_targets_next * (1 - dones) + + # Compute loss and perform a gradient step + self.optimizer.zero_grad() + loss = F.mse_loss(Q_expected, Q_targets) + loss.backward() + self.optimizer.step() + + # Update the target network parameters to `tau * local.parameters() + (1 - tau) * target.parameters()` + for target_param, local_param in zip(self.qnetwork_target.parameters(), self.qnetwork_local.parameters()): + target_param.data.copy_(TAU * local_param.data + (1.0 - TAU) * target_param.data) + + # Checkpointing methods + + def save(self, path, *data): + torch.save(self.qnetwork_local.state_dict(), path / 'dqn/model_checkpoint.local') + torch.save(self.qnetwork_target.state_dict(), path / 'dqn/model_checkpoint.target') + torch.save(self.optimizer.state_dict(), path / 'dqn/model_checkpoint.optimizer') + with open(path / 'dqn/model_checkpoint.meta', 'wb') as file: + pickle.dump(data, file) + + def load(self, path, *defaults): + loc = {} if torch.cuda.is_available() else {'map_location': torch.device('cpu')} + try: + print("Loading model from checkpoint...") + dqn = os.path.join(path, 'dqn') + self.qnetwork_local.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.local'), **loc)) + self.qnetwork_target.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.target'), **loc)) + self.optimizer.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.optimizer'), **loc)) + with open(os.path.join(dqn, 'model_checkpoint.meta'), 'rb') as file: + return pickle.load(file) + except Exception as exc: + import traceback + traceback.print_exc() + print(f"Got exception {exc} loading model data. Possibly no checkpoint found.") + return defaults From 9d8a8fc187b54341fb140c9725512c4ace1c8d13 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Jul 2020 09:50:38 +0200 Subject: [PATCH 04/75] perf: use torch only (no numpy) --- src/agent.py | 16 +++++++--------- src/observation_utils.py | 41 ++++++++++++++++++++++++++-------------- src/railway_utils.py | 6 +++--- src/replay_memory.py | 24 +++++++++++++---------- src/train.py | 22 ++++++++++++++------- 5 files changed, 66 insertions(+), 43 deletions(-) diff --git a/src/agent.py b/src/agent.py index 7d7acdc..5969572 100644 --- a/src/agent.py +++ b/src/agent.py @@ -19,7 +19,7 @@ GAMMA = 0.998 TAU = 1e-3 LR = 3e-5 -UPDATE_EVERY = 200 +UPDATE_EVERY = 1 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -44,10 +44,10 @@ def reset(self): # Decide on an action to take in the environment def act(self, state, eps=0.): - state = torch.from_numpy(state).float().unsqueeze(0).to(device) + state = state.unsqueeze(0).to(device) self.qnetwork_local.eval() with torch.no_grad(): - action_values = self.qnetwork_local(state) + action_values = self.qnetwork_local(state)[0] # Epsilon-greedy action selection if random.random() > eps: @@ -79,13 +79,11 @@ def learn(self, states, actions, rewards, next_states, dones): self.qnetwork_local.train() # Get expected Q values from local model - Q_expected = self.qnetwork_local(states).gather(1, actions) + Q_expected = self.qnetwork_local(states) + Q_expected = Q_expected.gather(1, actions) - if self.double_dqn: - Q_best_action = self.qnetwork_local(next_states).argmax(1) - Q_targets_next = self.qnetwork_target(next_states).gather(1, Q_best_action.unsqueeze(-1)) - else: - Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) + Q_best_action = self.qnetwork_local(next_states).argmax(1) + Q_targets_next = self.qnetwork_target(next_states).gather(1, Q_best_action.unsqueeze(-1)) # Compute Q targets for current states Q_targets = rewards + GAMMA * Q_targets_next * (1 - dones) diff --git a/src/observation_utils.py b/src/observation_utils.py index 5999d73..b146068 100644 --- a/src/observation_utils.py +++ b/src/observation_utils.py @@ -1,4 +1,5 @@ import numpy as np +import torch try: from .tree_observation import ACTIONS @@ -35,19 +36,29 @@ def create_tree_features(node, current_depth, max_depth, empty_node, data): return data +TRUE = torch.ones(1) +FALSE = torch.zeros(1) + + # Normalize an observation to [0, 1] and then clip it to get rid of any infinite-valued features -def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False): - if fixed_radius > 0: - max_obs = fixed_radius - else: - max_obs = np.max(obs[np.where(obs < 1000)], initial=1) + 1 +#@torch.jit.script +def norm_obs_clip(obs, normalize_to_range): + max_obs = obs[obs < 1000].max() + max_obs.clamp_(min=1) + max_obs.add_(1) - min_obs = np.min(obs[np.where(obs >= 0)], initial=max_obs) if normalize_to_range else 0 + min_obs = torch.zeros(1)[0] + + if normalize_to_range.item(): + min_obs.add_(obs[obs >= 0].min().clamp(max=max_obs.item())) if max_obs == min_obs: - return np.clip(obs / max_obs, clip_min, clip_max) + obs.div_(max_obs) else: - return np.clip((obs - min_obs) / np.abs(max_obs - min_obs), clip_min, clip_max) + obs.sub_(min_obs) + max_obs.sub_(min_obs) + obs.div_(max_obs) + return obs # Normalize a tree observation @@ -56,12 +67,14 @@ def normalize_observation(tree, max_depth, zero_center=True): data = np.concatenate([create_tree_features(t, 0, max_depth, empty_node, []) for t in tree.values()] if isinstance(tree, dict) else create_tree_features(tree, 0, max_depth, empty_node, [])).reshape((-1, 11)) + data = torch.as_tensor(data).float() - obs_data = norm_obs_clip(data[:, :6].flatten()) - distances = norm_obs_clip(data[:, 6], normalize_to_range=True) - agent_data = np.clip(data[:, 7:].flatten(), -1, 1) + norm_obs_clip(data[:, :6], FALSE) + norm_obs_clip(data[:, 6], TRUE) + data.clamp_(-1, 1) if zero_center: - return np.concatenate((obs_data - obs_data.mean(), distances, agent_data - agent_data.mean())) - else: - return np.concatenate((obs_data, distances, agent_data)) + data[:, :6].sub_(data[:, :6].mean()) + data[:, 7:].sub_(data[:, 7:].mean()) + + return data.view(-1) diff --git a/src/railway_utils.py b/src/railway_utils.py index b17c025..a533fa2 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -43,9 +43,9 @@ def __call__(self, *args, **kwargs): # Helper function to load in precomputed railway networks def load_precomputed_railways(project_root, start_index): prefix = os.path.join(project_root, 'railroads') - suffix = f'_sum.pkl' - rail = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) - sched = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) + suffix = f'_3x30x30.pkl' + sched = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) + rail = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) print(f"Working on {len(rail)} tracks") return rail, sched diff --git a/src/replay_memory.py b/src/replay_memory.py index 61a1b81..461e870 100644 --- a/src/replay_memory.py +++ b/src/replay_memory.py @@ -1,8 +1,8 @@ -import torch import random -import numpy as np from collections import namedtuple, deque, Iterable +import numpy as np +import torch Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done")) @@ -28,7 +28,7 @@ def __init__(self, buffer_size): self.memory = deque(maxlen=buffer_size) def push(self, state, action, reward, next_state, done): - self.memory.append(Transition(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)) + self.memory.append(Transition(state.unsqueeze(0), action, reward, next_state.unsqueeze(0), done)) def push_episode(self, episode): for step in episode.memory: @@ -37,17 +37,21 @@ def push_episode(self, episode): def sample(self, batch_size, device): experiences = random.sample(self.memory, k=batch_size) - states = torch.from_numpy(self.stack([e.state for e in experiences])).float().to(device) - actions = torch.from_numpy(self.stack([e.action for e in experiences])).long().to(device) - rewards = torch.from_numpy(self.stack([e.reward for e in experiences])).float().to(device) - next_states = torch.from_numpy(self.stack([e.next_state for e in experiences])).float().to(device) - dones = torch.from_numpy(self.stack([e.done for e in experiences]).astype(np.uint8)).float().to(device) + states = self.stack([e.state for e in experiences]).float().to(device) + actions = self.stack([e.action for e in experiences]).long().to(device) + rewards = self.stack([e.reward for e in experiences]).float().to(device) + next_states = self.stack([e.next_state for e in experiences]).float().to(device) + dones = self.stack([e.done for e in experiences]).float().to(device) return states, actions, rewards, next_states, dones def stack(self, states): - sub_dims = states[0].shape[1:] if isinstance(states[0], Iterable) else [1] - return np.reshape(np.array(states), (len(states), *sub_dims)) + return (torch.stack(states, 0) + if isinstance(states[0], torch.Tensor) + else torch.tensor(states).view(len(states), + *(states[0].shape[1:] + if isinstance(states[0], Iterable) + else [1]))) def __len__(self): return len(self.memory) diff --git a/src/train.py b/src/train.py index 2ffc0f4..1b43800 100644 --- a/src/train.py +++ b/src/train.py @@ -5,18 +5,19 @@ import cv2 import numpy as np +import torch from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool, AgentRenderVariant from tensorboardX import SummaryWriter try: - from .agent import Agent as DQN_Agent + from .agent import Agent as DQN_Agent, device, BATCH_SIZE from .tree_observation import TreeObservation from .observation_utils import normalize_observation, is_collision from .railway_utils import load_precomputed_railways, create_random_railways except: - from agent import Agent as DQN_Agent + from agent import Agent as DQN_Agent, device, BATCH_SIZE from tree_observation import TreeObservation from observation_utils import normalize_observation, is_collision from railway_utils import load_precomputed_railways, create_random_railways @@ -44,8 +45,8 @@ # Training parameters parser.add_argument("--agent-type", default="dqn", choices=["dqn", "ppo"], help="Which type of RL agent to use") -parser.add_argument("--num-episodes", type=int, default=10**6, help="Number of episodes to train for") -parser.add_argument("--epsilon-decay", type=float, default=0.999, help="Decay factor for epsilon-greedy exploration") +parser.add_argument("--num-episodes", type=int, default=10 ** 6, help="Number of episodes to train for") +parser.add_argument("--epsilon-decay", type=float, default=0, help="Decay factor for epsilon-greedy exploration") parser.add_argument("--step-reward", type=float, default=-1e-2, help="Depth of the observation tree") flags = parser.parse_args() @@ -121,6 +122,7 @@ def get_means(x, y, c, s): return (x * 3 + c) / 4, (y * (s - 1) + c) / s + episode = 0 # Main training loop @@ -165,9 +167,15 @@ def get_means(x, y, c, s): # Update replay buffer and train agent for a in range(agent_count): if flags.train and (update_values[a] or done[a] or done['__all__']): - agent.step(a, agent_obs_buffer[a], agent_action_buffer[a], agent_obs[a], done[a], done['__all__'], - is_collision(a), flags.step_reward) - agent_obs_buffer[a] = agent_obs[a].copy() + agent.step(a, + agent_obs_buffer[a], + agent_action_buffer[a], + agent_obs[a], + done[a], + done['__all__'], + is_collision(a), + flags.step_reward) + agent_obs_buffer[a] = agent_obs[a].clone() agent_action_buffer[a] = action_dict[a] if obs[a]: From 6669ed3bb0dd79db45b96d03f7d957d63fda6c0a Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Jul 2020 09:50:51 +0200 Subject: [PATCH 05/75] feat: add baseline model --- src/model.py | 196 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 196 insertions(+) create mode 100644 src/model.py diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000..90068ad --- /dev/null +++ b/src/model.py @@ -0,0 +1,196 @@ +import math +import typing + +import numpy as np +import torch + + +@torch.jit.script +def mish(fn_input: torch.Tensor) -> torch.Tensor: + return fn_input * torch.tanh(torch.nn.functional.softplus(fn_input)) + + +class WeightDropLinear(torch.nn.Module): + """ + Wrapper around :class:`torch.nn.Linear` that adds ``weight_dropout`` named argument. + + Args: + weight_dropout (float): The probability a weight will be dropped. + """ + + def __init__(self, in_features: int, out_features: int, bias=True, weight_dropout=0.0): + super().__init__() + self.weight_dropout = weight_dropout + self.in_features = in_features + self.out_features = out_features + self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) + if bias: + self.bias = torch.nn.Parameter(torch.Tensor(out_features)) + else: + self.register_parameter('bias', None) + + def forward(self, fn_input): + if self.training: + weight = self.weight.bernoulli(p=self.weight_dropout) * self.weight + else: + weight = self.weight + + return torch.nn.functional.linear(fn_input, weight, self.bias) + + def extra_repr(self): + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features, self.bias is not None + ) + + +class SeparableConvolution(torch.nn.Module): + def __init__(self, in_features, out_factor, kernel_size: typing.Union[int, tuple], + padding: typing.Union[int, tuple] = 0, dilation: typing.Union[int, tuple] = 1, + norm=torch.nn.BatchNorm1d): + super(SeparableConvolution, self).__init__() + out_features = in_features * out_factor + self.depthwise_conv = torch.nn.Conv1d(in_features, out_features, + kernel_size, + padding=padding, + groups=in_features, + dilation=dilation, + bias=False) + self.mid_norm = norm(out_features) + self.pointwise_conv = torch.nn.Conv1d(out_features, out_features, 1, bias=False) + self.str = (f'SeparableConvolution({in_features}, {out_features}, {kernel_size}, ' + + f'dilation={dilation}, padding={padding})') + + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + return self.pointwise_conv(self.mid_norm(self.depthwise_conv(fn_input.transpose(1, 2)))).squeeze(-1) + + def __str__(self): + return self.str + + def __repr__(self): + return self.str + + +def try_norm(tensor, norm): + if norm is not None: + tensor = mish(norm(tensor)) + return tensor + + +class Block(torch.nn.Module): + def __init__(self, hidden_size, output_size, bias=False, dropout=0.1, cat=True, init_norm=False, out_norm=True): + super().__init__() + self.residual = hidden_size == output_size + self.cat = cat + + self.init_norm = torch.nn.BatchNorm1d(hidden_size) if init_norm else None + self.linr = WeightDropLinear(hidden_size, output_size, bias=bias, weight_dropout=dropout) + self.out_norm = torch.nn.BatchNorm1d(output_size) if out_norm else None + + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + fn_input = try_norm(fn_input, self.init_norm) + out = self.linr(fn_input) + out = try_norm(out, self.out_norm) + if self.cat: + return torch.cat([out, fn_input], 1) + if self.residual: + return out + fn_input + return out + + +# class Residual(torch.nn.Module): +# def __init__(self, m1, m2=None): +# import random +# super().__init__() +# self.m1 = m1 +# self.m2 = copy.deepcopy(m1) if m2 is None else m2 +# self.name = f'residual_{str(int(random.randint(0, 2 ** 32)))}' +# +# def forward(self, fn_input: torch.Tensor) -> torch.Tensor: +# double = fn_input.size(1) > 1 +# if double: +# f0, f1 = fn_input.chunk(2, 1) +# o0 = self.m1(f0) +# o1 = self.m2(f1) +# return torch.cat([o0, o1], 1) + fn_input +# else: +# return self.m1(fn_input) + self.m2(fn_input) + fn_input +# +# def __str__(self): +# return f'{self.__class__.__name__}(ID: {self.name}, M1: {self.m1}, M2: {self.m2})' +# +# def __repr__(self): +# return str(self) +# +# +# def layer_split(target_depth, features, split_depth=3, uneven: typing.Union[bool, int] = False): +# layer_list = [] +# +# if target_depth > split_depth ** 2: +# for _ in range(split_depth): +# layer_list.append(layer_split(target_depth // split_depth, features // 2, split_depth, features % 2)) +# layer_list.append(layer_split(target_depth % split_depth, features // 2, split_depth, features % 2)) +# elif target_depth > split_depth: +# for _ in range(target_depth // split_depth): +# layer_list.append(layer_split(split_depth, features // 2, split_depth, features % 2)) +# layer_list.append(layer_split(target_depth % split_depth, features // 2, split_depth, features % 2)) +# else: +# tmp_features = max(2, features) +# f2, mod = tmp_features // 2, tmp_features % 2 +# layer_list = [Residual(Block(f2 + mod, f2 + mod), Block(f2, f2)) for _ in range(target_depth)] +# layer = torch.nn.Sequential(*layer_list) +# features = max(1, features + uneven) +# layer = Residual(Block(features, features), layer) +# return layer + + +class QNetwork(torch.nn.Module): + def __init__(self, state_size, action_size, hidden_factor=15, depth=4, message_box=16, cat=True): + """ + 11 input features, state_size//11 = item_count + :param state_size: + :param action_size: + :param hidden_factor: + :param depth: + :return: + """ + super(QNetwork, self).__init__() + observations = state_size // 11 + print(f"[DEBUG/MODEL] Using {observations} observations as input") + + out_features = hidden_factor * 11 + + net = torch.nn.ModuleList([torch.nn.Linear(state_size, out_features), + *[Block(out_features + out_features * i * cat, out_features, cat=True, + init_norm=not i) + for i in range(depth)], + Block(out_features + out_features * depth * cat, action_size, + bias=True, + cat=False, + out_norm=False, + init_norm=False)]) + + def init(module: torch.nn.Module): + if hasattr(module, "weight") and hasattr(module.weight, "data"): + if "norm" in module.__class__.__name__.lower() or ( + hasattr(module, "__str__") and "norm" in str(module).lower()): + torch.nn.init.uniform_(module.weight.data, 0.998, 1.002) + else: + torch.nn.init.orthogonal_(module.weight.data) + if hasattr(module, "bias") and hasattr(module.bias, "data"): + torch.nn.init.constant_(module.bias.data, 0) + + net.apply(init) + + parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, net.parameters())) + digits = int(math.log10(parameters)) + number_string = " kMGTPEZY"[digits // 3] + + print(f"[DEBUG/MODEL] Training with {parameters * 10 ** -(digits // 3 * 3):.1f}{number_string} parameters") + + self.net = net + + def forward(self, fn_input: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]: + out = fn_input.view(fn_input.size(0), -1) + for module in self.net: + out = module(out) + return out From 2e83da3afc6f07dcd287836d5d7bdccf5c514c93 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Jul 2020 11:14:23 +0200 Subject: [PATCH 06/75] feat: train all agents at the same time --- src/agent.py | 37 ++++++++++++------------ src/model.py | 62 +++++++++++++++++++++++----------------- src/observation_utils.py | 2 +- src/replay_memory.py | 24 +++++++++------- src/train.py | 58 ++++++++++++++++++++----------------- 5 files changed, 100 insertions(+), 83 deletions(-) diff --git a/src/agent.py b/src/agent.py index 5969572..e19cf62 100644 --- a/src/agent.py +++ b/src/agent.py @@ -25,12 +25,11 @@ class Agent: - def __init__(self, state_size, action_size, num_agents, model_depth, hidden_factor, double_dqn=True): + def __init__(self, state_size, action_size, num_agents, model_depth, hidden_factor, kernel_size=7): self.action_size = action_size - self.double_dqn = double_dqn # Q-Network - self.qnetwork_local = QNetwork(state_size, action_size, hidden_factor, model_depth).to(device) + self.qnetwork_local = QNetwork(state_size, action_size, hidden_factor, model_depth, kernel_size).to(device) self.qnetwork_target = copy.deepcopy(self.qnetwork_local) self.optimizer = Optimizer(self.qnetwork_local.parameters(), lr=LR, weight_decay=1e-2) @@ -39,26 +38,27 @@ def __init__(self, state_size, action_size, num_agents, model_depth, hidden_fact self.t_step = 0 def reset(self): - self.finished = [False] * 40 # Up to 40 agents used + self.finished = False # Decide on an action to take in the environment def act(self, state, eps=0.): - state = state.unsqueeze(0).to(device) + agent_count = len(state) + state = torch.stack(state, -1).unsqueeze(0).to(device) self.qnetwork_local.eval() with torch.no_grad(): - action_values = self.qnetwork_local(state)[0] + action_values = self.qnetwork_local(state) # Epsilon-greedy action selection - if random.random() > eps: - return torch.argmax(action_values).item() - else: - return torch.randint(self.action_size, ()).item() + return [torch.argmax(action_values[:, :, i], 1).item() + if random.random() > eps + else torch.randint(self.action_size, ()).item() + for i in range(agent_count)] # Record the results of the agent's action and update the model - def step(self, handle, state, action, next_state, agent_done, episode_done, collision, step_reward=-1): - if not self.finished[handle]: + def step(self, state, action, next_state, agent_done, episode_done, collision, step_reward=-1): + if not self.finished: if agent_done: reward = 1 elif collision: @@ -68,7 +68,7 @@ def step(self, handle, state, action, next_state, agent_done, episode_done, coll # Save experience in replay memory self.memory.push(state, action, reward, next_state, agent_done or episode_done) - self.finished[handle] = agent_done or episode_done + self.finished = episode_done # Perform a gradient update every UPDATE_EVERY time steps self.t_step = (self.t_step + 1) % UPDATE_EVERY @@ -78,15 +78,16 @@ def step(self, handle, state, action, next_state, agent_done, episode_done, coll def learn(self, states, actions, rewards, next_states, dones): self.qnetwork_local.train() + # Get expected Q values from local model - Q_expected = self.qnetwork_local(states) - Q_expected = Q_expected.gather(1, actions) + Q_expected = self.qnetwork_local(states.squeeze(1)) - Q_best_action = self.qnetwork_local(next_states).argmax(1) - Q_targets_next = self.qnetwork_target(next_states).gather(1, Q_best_action.unsqueeze(-1)) + Q_expected = Q_expected.gather(1, actions.unsqueeze(1)) + Q_best_action = self.qnetwork_local(next_states.squeeze(1)).argmax(1) + Q_targets_next = self.qnetwork_target(next_states.squeeze(1)).gather(1, Q_best_action.unsqueeze(1)) # Compute Q targets for current states - Q_targets = rewards + GAMMA * Q_targets_next * (1 - dones) + Q_targets = rewards.unsqueeze(-1) + GAMMA * Q_targets_next * (1 - dones.unsqueeze(-1)) # Compute loss and perform a gradient step self.optimizer.zero_grad() diff --git a/src/model.py b/src/model.py index 90068ad..89430c8 100644 --- a/src/model.py +++ b/src/model.py @@ -10,7 +10,7 @@ def mish(fn_input: torch.Tensor) -> torch.Tensor: return fn_input * torch.tanh(torch.nn.functional.softplus(fn_input)) -class WeightDropLinear(torch.nn.Module): +class WeightDropConv(torch.nn.Module): """ Wrapper around :class:`torch.nn.Linear` that adds ``weight_dropout`` named argument. @@ -18,16 +18,22 @@ class WeightDropLinear(torch.nn.Module): weight_dropout (float): The probability a weight will be dropped. """ - def __init__(self, in_features: int, out_features: int, bias=True, weight_dropout=0.0): + def __init__(self, in_features: int, out_features: int, kernel_size=1, bias=True, weight_dropout=0.1, groups=1, + padding=0, dilation=1): super().__init__() self.weight_dropout = weight_dropout self.in_features = in_features self.out_features = out_features - self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) + if in_features % groups != 0: + print(f"[ERROR] Unable to get weight for in={in_features},groups={groups}. Make sure they are divisible.") + if out_features % groups != 0: + print(f"[ERROR] Unable to get weight for out={out_features},groups={groups}. Make sure they are divisible.") + self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features // groups, kernel_size)) if bias: self.bias = torch.nn.Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) + self._kwargs = {'bias': self.bias, 'padding': padding, 'dilation': dilation, 'groups': groups} def forward(self, fn_input): if self.training: @@ -35,33 +41,30 @@ def forward(self, fn_input): else: weight = self.weight - return torch.nn.functional.linear(fn_input, weight, self.bias) + return torch.nn.functional.conv1d(fn_input, weight, **self._kwargs) def extra_repr(self): - return 'in_features={}, out_features={}, bias={}'.format( - self.in_features, self.out_features, self.bias is not None - ) + return 'in_features={}, out_features={}'.format(self.in_features, self.out_features) class SeparableConvolution(torch.nn.Module): - def __init__(self, in_features, out_factor, kernel_size: typing.Union[int, tuple], + def __init__(self, in_features, out_features, kernel_size: typing.Union[int, tuple], padding: typing.Union[int, tuple] = 0, dilation: typing.Union[int, tuple] = 1, - norm=torch.nn.BatchNorm1d): + norm=torch.nn.BatchNorm1d, bias=False): super(SeparableConvolution, self).__init__() - out_features = in_features * out_factor - self.depthwise_conv = torch.nn.Conv1d(in_features, out_features, - kernel_size, - padding=padding, - groups=in_features, - dilation=dilation, - bias=False) - self.mid_norm = norm(out_features) - self.pointwise_conv = torch.nn.Conv1d(out_features, out_features, 1, bias=False) + self.depthwise_conv = WeightDropConv(in_features, in_features, + kernel_size, + padding=padding, + groups=in_features, + dilation=dilation, + bias=False) + self.mid_norm = norm(in_features) + self.pointwise_conv = WeightDropConv(in_features, out_features, 1, bias=bias) self.str = (f'SeparableConvolution({in_features}, {out_features}, {kernel_size}, ' + f'dilation={dilation}, padding={padding})') def forward(self, fn_input: torch.Tensor) -> torch.Tensor: - return self.pointwise_conv(self.mid_norm(self.depthwise_conv(fn_input.transpose(1, 2)))).squeeze(-1) + return self.pointwise_conv(self.mid_norm(self.depthwise_conv(fn_input))) def __str__(self): return self.str @@ -77,13 +80,14 @@ def try_norm(tensor, norm): class Block(torch.nn.Module): - def __init__(self, hidden_size, output_size, bias=False, dropout=0.1, cat=True, init_norm=False, out_norm=True): + def __init__(self, hidden_size, output_size, bias=False, cat=True, init_norm=False, out_norm=True, + kernel_size=7): super().__init__() self.residual = hidden_size == output_size self.cat = cat self.init_norm = torch.nn.BatchNorm1d(hidden_size) if init_norm else None - self.linr = WeightDropLinear(hidden_size, output_size, bias=bias, weight_dropout=dropout) + self.linr = SeparableConvolution(hidden_size, output_size, kernel_size, padding=kernel_size // 2, bias=bias) self.out_norm = torch.nn.BatchNorm1d(output_size) if out_norm else None def forward(self, fn_input: torch.Tensor) -> torch.Tensor: @@ -144,7 +148,7 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: class QNetwork(torch.nn.Module): - def __init__(self, state_size, action_size, hidden_factor=15, depth=4, message_box=16, cat=True): + def __init__(self, state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, cat=True): """ 11 input features, state_size//11 = item_count :param state_size: @@ -159,15 +163,19 @@ def __init__(self, state_size, action_size, hidden_factor=15, depth=4, message_b out_features = hidden_factor * 11 - net = torch.nn.ModuleList([torch.nn.Linear(state_size, out_features), - *[Block(out_features + out_features * i * cat, out_features, cat=True, - init_norm=not i) + net = torch.nn.ModuleList([torch.nn.Conv1d(state_size, out_features, 1), + *[Block(out_features + out_features * i * cat, + out_features, + cat=True, + init_norm=not i, + kernel_size=kernel_size) for i in range(depth)], Block(out_features + out_features * depth * cat, action_size, bias=True, cat=False, out_norm=False, - init_norm=False)]) + init_norm=False, + kernel_size=kernel_size)]) def init(module: torch.nn.Module): if hasattr(module, "weight") and hasattr(module.weight, "data"): @@ -190,7 +198,7 @@ def init(module: torch.nn.Module): self.net = net def forward(self, fn_input: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]: - out = fn_input.view(fn_input.size(0), -1) + out = fn_input for module in self.net: out = module(out) return out diff --git a/src/observation_utils.py b/src/observation_utils.py index b146068..61ea7c6 100644 --- a/src/observation_utils.py +++ b/src/observation_utils.py @@ -77,4 +77,4 @@ def normalize_observation(tree, max_depth, zero_center=True): data[:, :6].sub_(data[:, :6].mean()) data[:, 7:].sub_(data[:, 7:].mean()) - return data.view(-1) + return data.flatten() diff --git a/src/replay_memory.py b/src/replay_memory.py index 461e870..e4667ac 100644 --- a/src/replay_memory.py +++ b/src/replay_memory.py @@ -1,7 +1,6 @@ import random from collections import namedtuple, deque, Iterable -import numpy as np import torch Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done")) @@ -28,7 +27,11 @@ def __init__(self, buffer_size): self.memory = deque(maxlen=buffer_size) def push(self, state, action, reward, next_state, done): - self.memory.append(Transition(state.unsqueeze(0), action, reward, next_state.unsqueeze(0), done)) + self.memory.append(Transition(torch.stack(state, -1).unsqueeze(0), + action, + reward, + torch.stack(next_state, -1).unsqueeze(0), + done)) def push_episode(self, episode): for step in episode.memory: @@ -41,17 +44,18 @@ def sample(self, batch_size, device): actions = self.stack([e.action for e in experiences]).long().to(device) rewards = self.stack([e.reward for e in experiences]).float().to(device) next_states = self.stack([e.next_state for e in experiences]).float().to(device) - dones = self.stack([e.done for e in experiences]).float().to(device) + dones = self.stack([list(e.done.values()) for e in experiences]).float().to(device) return states, actions, rewards, next_states, dones - def stack(self, states): - return (torch.stack(states, 0) - if isinstance(states[0], torch.Tensor) - else torch.tensor(states).view(len(states), - *(states[0].shape[1:] - if isinstance(states[0], Iterable) - else [1]))) + def stack(self, states, dim=0): + if isinstance(states[0], Iterable): + if isinstance(states[0][0], list): + return torch.stack([self.stack(st, -1) for st in states], dim) + if isinstance(states[0], torch.Tensor): + return torch.stack(states, 0) + return torch.tensor(states) + return torch.tensor(states).view(len(states), 1) def __len__(self): return len(self.memory) diff --git a/src/train.py b/src/train.py index 1b43800..30fb15b 100644 --- a/src/train.py +++ b/src/train.py @@ -5,7 +5,6 @@ import cv2 import numpy as np -import torch from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool, AgentRenderVariant @@ -42,6 +41,7 @@ parser.add_argument("--tree-depth", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--model-depth", type=int, default=4, help="Depth of the observation tree") parser.add_argument("--hidden-factor", type=int, default=15, help="Depth of the observation tree") +parser.add_argument("--kernel-size", type=int, default=7, help="Depth of the observation tree") # Training parameters parser.add_argument("--agent-type", default="dqn", choices=["dqn", "ppo"], help="Which type of RL agent to use") @@ -62,7 +62,7 @@ state_size = num_nodes * num_features_per_node action_size = 5 # Load an RL agent and initialize it from checkpoint if necessary -agent = DQN_Agent(state_size, action_size, flags.num_agents, flags.model_depth, flags.hidden_factor) +agent = DQN_Agent(state_size, action_size, flags.num_agents, flags.model_depth, flags.hidden_factor, flags.kernel_size) if flags.load_model: start, eps = agent.load(project_root / 'checkpoints', 0, 1.0) else: @@ -122,7 +122,6 @@ def get_means(x, y, c, s): return (x * 3 + c) / 4, (y * (s - 1) + c) / s - episode = 0 # Main training loop @@ -144,14 +143,18 @@ def get_means(x, y, c, s): update_values = [False] * agent_count action_dict = {} - for a in range(agent_count): - if info['action_required'][a]: - action_dict[a] = agent.act(agent_obs[a], eps=eps) + if any(info['action_required']): + ret_action = agent.act(agent_obs, eps=eps) + else: + ret_action = update_values + for idx, act in enumerate(ret_action): + if info['action_required'][idx]: + action_dict[idx] = act # action_dict[a] = np.random.randint(5) - update_values[a] = True + update_values[idx] = True steps_taken += 1 else: - action_dict[a] = 0 + action_dict[idx] = 0 # Environment step obs, rewards, done, info = env.step(action_dict) @@ -165,29 +168,30 @@ def get_means(x, y, c, s): # done['__all__'] = True # Update replay buffer and train agent - for a in range(agent_count): - if flags.train and (update_values[a] or done[a] or done['__all__']): - agent.step(a, - agent_obs_buffer[a], - agent_action_buffer[a], - agent_obs[a], - done[a], - done['__all__'], - is_collision(a), - flags.step_reward) - agent_obs_buffer[a] = agent_obs[a].clone() - agent_action_buffer[a] = action_dict[a] + if flags.train and (any(update_values) or any(done) or done['__all__']): + agent.step(agent_obs_buffer, + agent_action_buffer, + agent_obs, + done, + done['__all__'], + [is_collision(a) for a in range(agent_count)], + flags.step_reward) + agent_obs_buffer = [o.clone() for o in agent_obs] + for key, value in action_dict.items(): + agent_action_buffer[key] = value + for a in range(agent_count): if obs[a]: agent_obs[a] = normalize_observation(obs[a], flags.tree_depth, zero_center=True) - # Render - # if flags.render_interval and episode % flags.render_interval == 0: - # if collision and all(agent.position for agent in env.agents): - # render() - # print("Collisions detected by agent(s)", ', '.join(str(a) for a in obs if is_collision(a))) - # break - if done['__all__']: break + # Render + # if flags.render_interval and episode % flags.render_interval == 0: + # if collision and all(agent.position for agent in env.agents): + # render() + # print("Collisions detected by agent(s)", ', '.join(str(a) for a in obs if is_collision(a))) + # break + if done['__all__']: + break # Epsilon decay if flags.train: From 9fde2584c3738fd723744bffaac713b62e9f7b21 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Jul 2020 11:32:19 +0200 Subject: [PATCH 07/75] feat: add attention --- src/agent.py | 11 ++++++--- src/model.py | 53 ++++++++++++++++++++++++++++++++++++++++---- src/railway_utils.py | 2 +- src/train.py | 11 +++++++-- 4 files changed, 67 insertions(+), 10 deletions(-) diff --git a/src/agent.py b/src/agent.py index e19cf62..59e5b84 100644 --- a/src/agent.py +++ b/src/agent.py @@ -19,17 +19,22 @@ GAMMA = 0.998 TAU = 1e-3 LR = 3e-5 -UPDATE_EVERY = 1 +UPDATE_EVERY = 160 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") class Agent: - def __init__(self, state_size, action_size, num_agents, model_depth, hidden_factor, kernel_size=7): + def __init__(self, state_size, action_size, num_agents, model_depth, hidden_factor, kernel_size, squeeze_heads): self.action_size = action_size # Q-Network - self.qnetwork_local = QNetwork(state_size, action_size, hidden_factor, model_depth, kernel_size).to(device) + self.qnetwork_local = QNetwork(state_size, + action_size, + hidden_factor, + model_depth, + kernel_size, + squeeze_heads).to(device) self.qnetwork_target = copy.deepcopy(self.qnetwork_local) self.optimizer = Optimizer(self.qnetwork_local.parameters(), lr=LR, weight_decay=1e-2) diff --git a/src/model.py b/src/model.py index 89430c8..e1a21fd 100644 --- a/src/model.py +++ b/src/model.py @@ -81,7 +81,7 @@ def try_norm(tensor, norm): class Block(torch.nn.Module): def __init__(self, hidden_size, output_size, bias=False, cat=True, init_norm=False, out_norm=True, - kernel_size=7): + kernel_size=7, squeeze_heads=4): super().__init__() self.residual = hidden_size == output_size self.cat = cat @@ -90,14 +90,57 @@ def __init__(self, hidden_size, output_size, bias=False, cat=True, init_norm=Fal self.linr = SeparableConvolution(hidden_size, output_size, kernel_size, padding=kernel_size // 2, bias=bias) self.out_norm = torch.nn.BatchNorm1d(output_size) if out_norm else None + self.use_squeeze_attention = squeeze_heads > 0 + + if self.use_squeeze_attention: + self.squeeze_heads = squeeze_heads + self.exc_input_norm = torch.nn.BatchNorm1d(squeeze_heads) + self.expert_ranker = torch.nn.Linear(output_size, squeeze_heads, False) + self.excitation_conv = SeparableConvolution(output_size, squeeze_heads, kernel_size, + padding=kernel_size // 2) + self.linear_in_norm = torch.nn.BatchNorm1d(output_size * squeeze_heads) + self.linear0 = torch.nn.Linear(output_size * squeeze_heads, output_size, False) + self.exc_norm = torch.nn.BatchNorm1d(output_size) + self.linear1 = torch.nn.Linear(output_size, output_size) + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + batch = fn_input.size(0) fn_input = try_norm(fn_input, self.init_norm) out = self.linr(fn_input) out = try_norm(out, self.out_norm) + + if self.use_squeeze_attention: + exc = self.excitation_conv(out) + exc = torch.nn.functional.softmax(exc, 2) + exc = exc.unsqueeze(-1).transpose(1, -1) + exc = (out.unsqueeze(-1) * exc).sum(2) + + # Rank experts (heads) + hds = exc.view(batch, self.squeeze_heads, -1) + exc = self.exc_input_norm(hds) + exc = self.expert_ranker(mish(exc)) + exc = exc.softmax(-1) + exc = exc.bmm(hds) + exc = exc.view(batch, -1, 1) + + # Fully-connected block + nrm = self.linear_in_norm(exc).squeeze(-1) + nrm = self.linear0(nrm).unsqueeze(-1) + nrm = self.exc_norm(nrm) + act = mish(nrm.squeeze(-1)) + exc = self.linear1(act).tanh() + exc = exc.unsqueeze(-1) + exc = exc.expand_as(out) + + # Merge + out = out * exc + if self.cat: return torch.cat([out, fn_input], 1) + if self.residual: return out + fn_input + return out @@ -148,7 +191,7 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: class QNetwork(torch.nn.Module): - def __init__(self, state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, cat=True): + def __init__(self, state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=True): """ 11 input features, state_size//11 = item_count :param state_size: @@ -168,14 +211,16 @@ def __init__(self, state_size, action_size, hidden_factor=15, depth=4, kernel_si out_features, cat=True, init_norm=not i, - kernel_size=kernel_size) + kernel_size=kernel_size, + squeeze_heads=squeeze_heads) for i in range(depth)], Block(out_features + out_features * depth * cat, action_size, bias=True, cat=False, out_norm=False, init_norm=False, - kernel_size=kernel_size)]) + kernel_size=kernel_size, + squeeze_heads=squeeze_heads)]) def init(module: torch.nn.Module): if hasattr(module, "weight") and hasattr(module.weight, "data"): diff --git a/src/railway_utils.py b/src/railway_utils.py index a533fa2..f5d12c9 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -43,7 +43,7 @@ def __call__(self, *args, **kwargs): # Helper function to load in precomputed railway networks def load_precomputed_railways(project_root, start_index): prefix = os.path.join(project_root, 'railroads') - suffix = f'_3x30x30.pkl' + suffix = f'_sum.pkl' sched = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) rail = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) print(f"Working on {len(rail)} tracks") diff --git a/src/train.py b/src/train.py index 30fb15b..42b2ca6 100644 --- a/src/train.py +++ b/src/train.py @@ -41,7 +41,8 @@ parser.add_argument("--tree-depth", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--model-depth", type=int, default=4, help="Depth of the observation tree") parser.add_argument("--hidden-factor", type=int, default=15, help="Depth of the observation tree") -parser.add_argument("--kernel-size", type=int, default=7, help="Depth of the observation tree") +parser.add_argument("--kernel-size", type=int, default=1, help="Depth of the observation tree") +parser.add_argument("--squeeze-heads", type=int, default=4, help="Depth of the observation tree") # Training parameters parser.add_argument("--agent-type", default="dqn", choices=["dqn", "ppo"], help="Which type of RL agent to use") @@ -62,7 +63,13 @@ state_size = num_nodes * num_features_per_node action_size = 5 # Load an RL agent and initialize it from checkpoint if necessary -agent = DQN_Agent(state_size, action_size, flags.num_agents, flags.model_depth, flags.hidden_factor, flags.kernel_size) +agent = DQN_Agent(state_size, + action_size, + flags.num_agents, + flags.model_depth, + flags.hidden_factor, + flags.kernel_size, + flags.squeeze_heads) if flags.load_model: start, eps = agent.load(project_root / 'checkpoints', 0, 1.0) else: From 6c6664d2cc7f247af774a34fb6b5a1f208f6e612 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Thu, 9 Jul 2020 11:59:33 +0200 Subject: [PATCH 08/75] feat: improve speed --- railroads2/cat.py | 29 ---- src/agent.py | 55 +++++--- src/model.py | 315 ++++++++++++++++++++++++------------------- src/railway_utils.py | 4 +- src/replay_memory.py | 3 +- src/train.py | 16 ++- 6 files changed, 229 insertions(+), 193 deletions(-) delete mode 100644 railroads2/cat.py diff --git a/railroads2/cat.py b/railroads2/cat.py deleted file mode 100644 index abbf632..0000000 --- a/railroads2/cat.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -import pickle -import random - - -def main(base, out_name): - files = sorted([i for i in os.listdir() if i.startswith(base) and not i.endswith('.bak') and 'sum' not in i]) - print(f'Concatenating {", ".join(files)}') - - out = [] - - for name in files: - with open(name, 'rb') as f: - try: - out.extend(pickle.load(f)) - except Exception as e: - print(f'Caught {e} while processing {name}') - - name = out_name+'sum.pkl' - random.seed(0) - random.shuffle(out) - with open(name, 'wb') as f: - pickle.dump(out, f) - print(f'Dumped {len(out)} items from {len(files)} sources to {name}') - -if __name__ == '__main__': - main('rail_networks_', 'schedules_') - main('schedules_', 'rail_networks_') - diff --git a/src/agent.py b/src/agent.py index 59e5b84..dbf4aaf 100644 --- a/src/agent.py +++ b/src/agent.py @@ -1,41 +1,50 @@ -import copy import pickle import random import torch -import torch.nn.functional as F from torch_optimizer import Yogi as Optimizer try: - from .model import QNetwork + from .model import QNetwork, ConvNetwork from .replay_memory import ReplayBuffer except: - from model import QNetwork + from model import QNetwork, ConvNetwork from replay_memory import ReplayBuffer import os BUFFER_SIZE = 500_000 -BATCH_SIZE = 512 +BATCH_SIZE = 256 GAMMA = 0.998 TAU = 1e-3 -LR = 3e-5 -UPDATE_EVERY = 160 +LR = 2e-4 +UPDATE_EVERY = 1 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") class Agent: - def __init__(self, state_size, action_size, num_agents, model_depth, hidden_factor, kernel_size, squeeze_heads): + def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_size, squeeze_heads, + use_global=False): self.action_size = action_size # Q-Network - self.qnetwork_local = QNetwork(state_size, + if use_global: + network = ConvNetwork + else: + network = QNetwork + self.qnetwork_local = network(state_size, + action_size, + hidden_factor, + model_depth, + kernel_size, + squeeze_heads).to(device) + self.qnetwork_target = network(state_size, action_size, hidden_factor, model_depth, kernel_size, - squeeze_heads).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + squeeze_heads, + debug=False).to(device) self.optimizer = Optimizer(self.qnetwork_local.parameters(), lr=LR, weight_decay=1e-2) # Replay memory @@ -60,6 +69,20 @@ def act(self, state, eps=0.): else torch.randint(self.action_size, ()).item() for i in range(agent_count)] + def multi_act(self, states, eps=0.): + agent_count = len(states[0]) + state = torch.stack([torch.stack(state, -1) for state in states], 0).to(device) + self.qnetwork_local.eval() + with torch.no_grad(): + action_values = self.qnetwork_local(state) + + # Epsilon-greedy action selection + return [[torch.argmax(act[:, :, i], 1).item() + if random.random() > eps + else torch.randint(self.action_size, ()).item() + for i in range(agent_count)] + for act in action_values.__iter__()] + # Record the results of the agent's action and update the model def step(self, state, action, next_state, agent_done, episode_done, collision, step_reward=-1): @@ -76,14 +99,13 @@ def step(self, state, action, next_state, agent_done, episode_done, collision, s self.finished = episode_done # Perform a gradient update every UPDATE_EVERY time steps - self.t_step = (self.t_step + 1) % UPDATE_EVERY - if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 20: + # self.t_step = (self.t_step + 1) % UPDATE_EVERY + if len(self.memory) > BATCH_SIZE * 20: self.learn(*self.memory.sample(BATCH_SIZE, device)) def learn(self, states, actions, rewards, next_states, dones): self.qnetwork_local.train() - # Get expected Q values from local model Q_expected = self.qnetwork_local(states.squeeze(1)) @@ -91,12 +113,9 @@ def learn(self, states, actions, rewards, next_states, dones): Q_best_action = self.qnetwork_local(next_states.squeeze(1)).argmax(1) Q_targets_next = self.qnetwork_target(next_states.squeeze(1)).gather(1, Q_best_action.unsqueeze(1)) - # Compute Q targets for current states - Q_targets = rewards.unsqueeze(-1) + GAMMA * Q_targets_next * (1 - dones.unsqueeze(-1)) - # Compute loss and perform a gradient step self.optimizer.zero_grad() - loss = F.mse_loss(Q_expected, Q_targets) + loss = (rewards.unsqueeze(-1) + GAMMA * Q_targets_next * (1 - dones.unsqueeze(-2)) - Q_expected).square().mean() loss.backward() self.optimizer.step() diff --git a/src/model.py b/src/model.py index e1a21fd..08288cc 100644 --- a/src/model.py +++ b/src/model.py @@ -1,7 +1,5 @@ -import math import typing -import numpy as np import torch @@ -10,6 +8,11 @@ def mish(fn_input: torch.Tensor) -> torch.Tensor: return fn_input * torch.tanh(torch.nn.functional.softplus(fn_input)) +class Mish(torch.nn.Module): + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + return mish(fn_input) + + class WeightDropConv(torch.nn.Module): """ Wrapper around :class:`torch.nn.Linear` that adds ``weight_dropout`` named argument. @@ -19,7 +22,7 @@ class WeightDropConv(torch.nn.Module): """ def __init__(self, in_features: int, out_features: int, kernel_size=1, bias=True, weight_dropout=0.1, groups=1, - padding=0, dilation=1): + padding=0, dilation=1, function=torch.nn.functional.conv1d, stride=1): super().__init__() self.weight_dropout = weight_dropout self.in_features = in_features @@ -33,7 +36,8 @@ def __init__(self, in_features: int, out_features: int, kernel_size=1, bias=True self.bias = torch.nn.Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) - self._kwargs = {'bias': self.bias, 'padding': padding, 'dilation': dilation, 'groups': groups} + self._kwargs = {'bias': self.bias, 'padding': padding, 'dilation': dilation, 'groups': groups, 'stride': stride} + self._function = function def forward(self, fn_input): if self.training: @@ -41,7 +45,7 @@ def forward(self, fn_input): else: weight = self.weight - return torch.nn.functional.conv1d(fn_input, weight, **self._kwargs) + return self._function(fn_input, weight, **self._kwargs) def extra_repr(self): return 'in_features={}, out_features={}'.format(self.in_features, self.out_features) @@ -50,21 +54,29 @@ def extra_repr(self): class SeparableConvolution(torch.nn.Module): def __init__(self, in_features, out_features, kernel_size: typing.Union[int, tuple], padding: typing.Union[int, tuple] = 0, dilation: typing.Union[int, tuple] = 1, - norm=torch.nn.BatchNorm1d, bias=False): + bias=False, dim=1, stride=1): super(SeparableConvolution, self).__init__() - self.depthwise_conv = WeightDropConv(in_features, in_features, - kernel_size, - padding=padding, - groups=in_features, - dilation=dilation, - bias=False) - self.mid_norm = norm(in_features) - self.pointwise_conv = WeightDropConv(in_features, out_features, 1, bias=bias) + self.depthwise = kernel_size > 1 + function = getattr(torch.nn.functional, f'conv{dim}d') + norm = getattr(torch.nn, f'BatchNorm{dim}d') + if self.depthwise: + self.depthwise_conv = WeightDropConv(in_features, in_features, + kernel_size, + padding=padding, + groups=in_features, + dilation=dilation, + bias=False, + function=function, + stride=stride) + self.mid_norm = norm(in_features) + self.pointwise_conv = WeightDropConv(in_features, out_features, 1, bias=bias, function=function) self.str = (f'SeparableConvolution({in_features}, {out_features}, {kernel_size}, ' + f'dilation={dilation}, padding={padding})') def forward(self, fn_input: torch.Tensor) -> torch.Tensor: - return self.pointwise_conv(self.mid_norm(self.depthwise_conv(fn_input))) + if self.depthwise: + fn_input = self.mid_norm(self.depthwise_conv(fn_input)) + return self.pointwise_conv(fn_input) def __str__(self): return self.str @@ -79,6 +91,43 @@ def try_norm(tensor, norm): return tensor +def make_excite(conv, rank_norm, ranker, linear_norm, linear0, excite_norm, linear1): + @torch.jit.script + def excite(out): + batch = out.size(0) + + exc = conv(out) + + squeeze_heads = exc.size(1) + + exc = torch.nn.functional.softmax(exc, 2) + exc = exc.unsqueeze(-1).transpose(1, -1) + exc = (out.unsqueeze(-1) * exc).sum(2) + + # Rank experts (heads) + hds = exc.view(batch, squeeze_heads, -1) + exc = rank_norm(hds) + exc = ranker(mish(exc)) + exc = exc.softmax(-1) + exc = exc.bmm(hds) + exc = exc.view(batch, -1, 1) + + # Fully-connected block + nrm = linear_norm(exc).squeeze(-1) + nrm = linear0(nrm).unsqueeze(-1) + nrm = excite_norm(nrm) + act = mish(nrm.squeeze(-1)) + exc = linear1(act).tanh() + exc = exc.unsqueeze(-1) + exc = exc.expand_as(out) + + # Merge + out = out * exc + return out + + return excite + + class Block(torch.nn.Module): def __init__(self, hidden_size, output_size, bias=False, cat=True, init_norm=False, out_norm=True, kernel_size=7, squeeze_heads=4): @@ -102,38 +151,21 @@ def __init__(self, hidden_size, output_size, bias=False, cat=True, init_norm=Fal self.linear0 = torch.nn.Linear(output_size * squeeze_heads, output_size, False) self.exc_norm = torch.nn.BatchNorm1d(output_size) self.linear1 = torch.nn.Linear(output_size, output_size) + self.excite = make_excite(self.excitation_conv, + self.exc_input_norm, + self.expert_ranker, + self.linear_in_norm, + self.linear0, + self.exc_norm, + self.linear1) def forward(self, fn_input: torch.Tensor) -> torch.Tensor: - batch = fn_input.size(0) fn_input = try_norm(fn_input, self.init_norm) out = self.linr(fn_input) out = try_norm(out, self.out_norm) if self.use_squeeze_attention: - exc = self.excitation_conv(out) - exc = torch.nn.functional.softmax(exc, 2) - exc = exc.unsqueeze(-1).transpose(1, -1) - exc = (out.unsqueeze(-1) * exc).sum(2) - - # Rank experts (heads) - hds = exc.view(batch, self.squeeze_heads, -1) - exc = self.exc_input_norm(hds) - exc = self.expert_ranker(mish(exc)) - exc = exc.softmax(-1) - exc = exc.bmm(hds) - exc = exc.view(batch, -1, 1) - - # Fully-connected block - nrm = self.linear_in_norm(exc).squeeze(-1) - nrm = self.linear0(nrm).unsqueeze(-1) - nrm = self.exc_norm(nrm) - act = mish(nrm.squeeze(-1)) - exc = self.linear1(act).tanh() - exc = exc.unsqueeze(-1) - exc = exc.expand_as(out) - - # Merge - out = out * exc + out = self.excite(out) if self.cat: return torch.cat([out, fn_input], 1) @@ -144,106 +176,115 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: return out -# class Residual(torch.nn.Module): -# def __init__(self, m1, m2=None): -# import random -# super().__init__() -# self.m1 = m1 -# self.m2 = copy.deepcopy(m1) if m2 is None else m2 -# self.name = f'residual_{str(int(random.randint(0, 2 ** 32)))}' +class BasicBlock(torch.nn.Module): + def __init__(self, in_features, out_features, stride, init_norm=False): + super(BasicBlock, self).__init__() + self.init_norm = torch.nn.BatchNorm2d(out_features) if init_norm else None + self.init_conv = SeparableConvolution(in_features, out_features, 3, 1, stride=stride, dim=2) + self.mid_norm = torch.nn.BatchNorm2d(out_features) + self.end_conv = SeparableConvolution(in_features, out_features, 3, 1, dim=2) + self.shortcut = (None + if stride == 1 and in_features == out_features + else SeparableConvolution(in_features, out_features, 3, 1, stride=stride, dim=2)) + + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + out = self.init_conv(fn_input if self.init_norm is None else mish(self.init_norm(fn_input))) + out = mish(self.mid_norm(out)) + out = self.end_conv(out) + if self.shortcut is not None: + fn_input = self.shortcut(fn_input) + out = out + fn_input + return out + + +class ConvNetwork(torch.nn.Module): + def __init__(self, state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=True, + debug=True): + super(ConvNetwork, self).__init__() + hidden_size = 11 * hidden_factor + self.net = torch.nn.ModuleList([BasicBlock(state_size, hidden_size, 1), + *[BasicBlock(hidden_size, hidden_size, 2 - i % 2, True) + for i in range(depth)]]) + self.init_norm = torch.nn.BatchNorm1d(hidden_size) + self.linear0 = torch.nn.Linear(hidden_size, hidden_size, bias=False) + self.mid_norm = torch.nn.BatchNorm1d(hidden_size) + self.linear1 = torch.nn.Linear(hidden_size, action_size) + + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + out = fn_input + for module in self.net: + out = module(out) + out = out.mean((2, 3)) + out = self.linear1(mish(self.mid_norm(self.linear0(mish(self.init_norm(out)))))) + return out + + +# class QNetwork(torch.nn.Module): +# def __init__(self, state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=False, +# debug=True): +# """ +# 11 input features, state_size//11 = item_count +# :param state_size: +# :param action_size: +# :param hidden_factor: +# :param depth: +# :return: +# """ +# super(QNetwork, self).__init__() +# observations = state_size // 11 +# if debug: +# print(f"[DEBUG/MODEL] Using {observations} observations as input") # -# def forward(self, fn_input: torch.Tensor) -> torch.Tensor: -# double = fn_input.size(1) > 1 -# if double: -# f0, f1 = fn_input.chunk(2, 1) -# o0 = self.m1(f0) -# o1 = self.m2(f1) -# return torch.cat([o0, o1], 1) + fn_input -# else: -# return self.m1(fn_input) + self.m2(fn_input) + fn_input +# out_features = hidden_factor * 11 # -# def __str__(self): -# return f'{self.__class__.__name__}(ID: {self.name}, M1: {self.m1}, M2: {self.m2})' +# net = torch.nn.ModuleList([torch.nn.Conv1d(state_size, out_features, 1), +# *[Block(out_features + out_features * i * cat, +# out_features, +# cat=cat, +# init_norm=not i, +# kernel_size=kernel_size, +# squeeze_heads=squeeze_heads) +# for i in range(depth)], +# Block(out_features + out_features * depth * cat, action_size, +# bias=True, +# cat=False, +# out_norm=False, +# init_norm=False, +# kernel_size=kernel_size, +# squeeze_heads=squeeze_heads)]) # -# def __repr__(self): -# return str(self) +# def init(module: torch.nn.Module): +# if hasattr(module, "weight") and hasattr(module.weight, "data"): +# if "norm" in module.__class__.__name__.lower() or ( +# hasattr(module, "__str__") and "norm" in str(module).lower()): +# torch.nn.init.uniform_(module.weight.data, 0.998, 1.002) +# else: +# torch.nn.init.orthogonal_(module.weight.data) +# if hasattr(module, "bias") and hasattr(module.bias, "data"): +# torch.nn.init.constant_(module.bias.data, 0) # +# net.apply(init) # -# def layer_split(target_depth, features, split_depth=3, uneven: typing.Union[bool, int] = False): -# layer_list = [] +# if debug: +# parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, net.parameters())) +# digits = int(math.log10(parameters)) +# number_string = " kMGTPEZY"[digits // 3] # -# if target_depth > split_depth ** 2: -# for _ in range(split_depth): -# layer_list.append(layer_split(target_depth // split_depth, features // 2, split_depth, features % 2)) -# layer_list.append(layer_split(target_depth % split_depth, features // 2, split_depth, features % 2)) -# elif target_depth > split_depth: -# for _ in range(target_depth // split_depth): -# layer_list.append(layer_split(split_depth, features // 2, split_depth, features % 2)) -# layer_list.append(layer_split(target_depth % split_depth, features // 2, split_depth, features % 2)) -# else: -# tmp_features = max(2, features) -# f2, mod = tmp_features // 2, tmp_features % 2 -# layer_list = [Residual(Block(f2 + mod, f2 + mod), Block(f2, f2)) for _ in range(target_depth)] -# layer = torch.nn.Sequential(*layer_list) -# features = max(1, features + uneven) -# layer = Residual(Block(features, features), layer) -# return layer - - -class QNetwork(torch.nn.Module): - def __init__(self, state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=True): - """ - 11 input features, state_size//11 = item_count - :param state_size: - :param action_size: - :param hidden_factor: - :param depth: - :return: - """ - super(QNetwork, self).__init__() - observations = state_size // 11 - print(f"[DEBUG/MODEL] Using {observations} observations as input") - - out_features = hidden_factor * 11 - - net = torch.nn.ModuleList([torch.nn.Conv1d(state_size, out_features, 1), - *[Block(out_features + out_features * i * cat, - out_features, - cat=True, - init_norm=not i, - kernel_size=kernel_size, - squeeze_heads=squeeze_heads) - for i in range(depth)], - Block(out_features + out_features * depth * cat, action_size, - bias=True, - cat=False, - out_norm=False, - init_norm=False, - kernel_size=kernel_size, - squeeze_heads=squeeze_heads)]) - - def init(module: torch.nn.Module): - if hasattr(module, "weight") and hasattr(module.weight, "data"): - if "norm" in module.__class__.__name__.lower() or ( - hasattr(module, "__str__") and "norm" in str(module).lower()): - torch.nn.init.uniform_(module.weight.data, 0.998, 1.002) - else: - torch.nn.init.orthogonal_(module.weight.data) - if hasattr(module, "bias") and hasattr(module.bias, "data"): - torch.nn.init.constant_(module.bias.data, 0) - - net.apply(init) - - parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, net.parameters())) - digits = int(math.log10(parameters)) - number_string = " kMGTPEZY"[digits // 3] - - print(f"[DEBUG/MODEL] Training with {parameters * 10 ** -(digits // 3 * 3):.1f}{number_string} parameters") - - self.net = net - - def forward(self, fn_input: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]: - out = fn_input - for module in self.net: - out = module(out) - return out +# print( +# f"[DEBUG/MODEL] Training with {parameters * 10 ** -(digits // 3 * 3):.1f}{number_string} parameters") +# +# self.net = net +# +# def forward(self, fn_input: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]: +# out = fn_input +# for module in self.net: +# out = module(out) +# return out + +def QNetwork(state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=False, + debug=True): + model = torch.nn.Sequential(torch.nn.Conv1d(state_size, 11 * hidden_factor, 1, bias=False), + torch.nn.BatchNorm1d(11 * hidden_factor), + Mish(), + torch.nn.Conv1d(11*hidden_factor, action_size, 1)) + return model diff --git a/src/railway_utils.py b/src/railway_utils.py index f5d12c9..b17c025 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -44,8 +44,8 @@ def __call__(self, *args, **kwargs): def load_precomputed_railways(project_root, start_index): prefix = os.path.join(project_root, 'railroads') suffix = f'_sum.pkl' - sched = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) - rail = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) + rail = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) + sched = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) print(f"Working on {len(rail)} tracks") return rail, sched diff --git a/src/replay_memory.py b/src/replay_memory.py index e4667ac..0fbe503 100644 --- a/src/replay_memory.py +++ b/src/replay_memory.py @@ -44,7 +44,8 @@ def sample(self, batch_size, device): actions = self.stack([e.action for e in experiences]).long().to(device) rewards = self.stack([e.reward for e in experiences]).float().to(device) next_states = self.stack([e.next_state for e in experiences]).float().to(device) - dones = self.stack([list(e.done.values()) for e in experiences]).float().to(device) + dones = self.stack([[v for k, v in e.done.items() if not hasattr(k, 'startswith') or not k.startswith('_')] + for e in experiences]).float().to(device) return states, actions, rewards, next_states, dones diff --git a/src/train.py b/src/train.py index 42b2ca6..77b7aee 100644 --- a/src/train.py +++ b/src/train.py @@ -7,6 +7,7 @@ import numpy as np from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters from flatland.envs.rail_env import RailEnv +from flatland.envs.observations import GlobalObsForRailEnv from flatland.utils.rendertools import RenderTool, AgentRenderVariant from tensorboardX import SummaryWriter @@ -38,9 +39,9 @@ parser.add_argument("--grid-width", type=int, default=50, help="Number of columns in the environment grid") parser.add_argument("--grid-height", type=int, default=50, help="Number of rows in the environment grid") parser.add_argument("--num-agents", type=int, default=5, help="Number of agents in each episode") -parser.add_argument("--tree-depth", type=int, default=1, help="Depth of the observation tree") -parser.add_argument("--model-depth", type=int, default=4, help="Depth of the observation tree") -parser.add_argument("--hidden-factor", type=int, default=15, help="Depth of the observation tree") +parser.add_argument("--tree-depth", type=int, default=2, help="Depth of the observation tree") +parser.add_argument("--model-depth", type=int, default=1, help="Depth of the observation tree") +parser.add_argument("--hidden-factor", type=int, default=5, help="Depth of the observation tree") parser.add_argument("--kernel-size", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--squeeze-heads", type=int, default=4, help="Depth of the observation tree") @@ -49,6 +50,7 @@ parser.add_argument("--num-episodes", type=int, default=10 ** 6, help="Number of episodes to train for") parser.add_argument("--epsilon-decay", type=float, default=0, help="Decay factor for epsilon-greedy exploration") parser.add_argument("--step-reward", type=float, default=-1e-2, help="Depth of the observation tree") +parser.add_argument("--global-environment", type=boolean, default=False, help="Depth of the observation tree") flags = parser.parse_args() @@ -65,11 +67,11 @@ # Load an RL agent and initialize it from checkpoint if necessary agent = DQN_Agent(state_size, action_size, - flags.num_agents, flags.model_depth, flags.hidden_factor, flags.kernel_size, - flags.squeeze_heads) + flags.squeeze_heads, + flags.global_environment) if flags.load_model: start, eps = agent.load(project_root / 'checkpoints', 0, 1.0) else: @@ -85,7 +87,9 @@ rail_generator=rail_generator, schedule_generator=schedule_generator, malfunction_generator_and_process_data=malfunction_from_params(MalfunctionParameters(1 / 8000, 15, 50)), - obs_builder_object=TreeObservation(max_depth=flags.tree_depth) + obs_builder_object=(GlobalObsForRailEnv() + if flags.global_environment + else TreeObservation(max_depth=flags.tree_depth)) ) # After training we want to render the results so we also load a renderer From e0ac945ef77760f2f6775df5ecfdf799c41599ce Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Thu, 9 Jul 2020 16:07:22 +0200 Subject: [PATCH 09/75] feat: process one batch at a time instead of caching --- src/agent.py | 42 +++++++------- src/model.py | 8 +-- src/observation_utils.py | 1 - src/railway_utils.py | 6 +- src/replay_memory.py | 6 +- src/train.py | 120 ++++++++++++++++++--------------------- 6 files changed, 90 insertions(+), 93 deletions(-) diff --git a/src/agent.py b/src/agent.py index dbf4aaf..517952e 100644 --- a/src/agent.py +++ b/src/agent.py @@ -13,7 +13,7 @@ import os BUFFER_SIZE = 500_000 -BATCH_SIZE = 256 +BATCH_SIZE = 64 GAMMA = 0.998 TAU = 1e-3 LR = 2e-4 @@ -48,7 +48,7 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s self.optimizer = Optimizer(self.qnetwork_local.parameters(), lr=LR, weight_decay=1e-2) # Replay memory - self.memory = ReplayBuffer(BUFFER_SIZE) + self.memory = ReplayBuffer(BATCH_SIZE) self.t_step = 0 def reset(self): @@ -59,6 +59,7 @@ def reset(self): def act(self, state, eps=0.): agent_count = len(state) state = torch.stack(state, -1).unsqueeze(0).to(device) + state = torch.cat([state, torch.randn(1, 1, state.size(-1), device=device)], 1) self.qnetwork_local.eval() with torch.no_grad(): action_values = self.qnetwork_local(state) @@ -71,13 +72,15 @@ def act(self, state, eps=0.): def multi_act(self, states, eps=0.): agent_count = len(states[0]) - state = torch.stack([torch.stack(state, -1) for state in states], 0).to(device) + state = torch.stack([torch.stack(state, -1) if len(state) > 1 else state.unsqueeze(-1) + for state in states], 0).to(device) + state = torch.cat([state, torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) self.qnetwork_local.eval() with torch.no_grad(): action_values = self.qnetwork_local(state) # Epsilon-greedy action selection - return [[torch.argmax(act[:, :, i], 1).item() + return [[torch.argmax(act[:, i], 0).item() if random.random() > eps else torch.randint(self.action_size, ()).item() for i in range(agent_count)] @@ -86,26 +89,25 @@ def multi_act(self, states, eps=0.): # Record the results of the agent's action and update the model def step(self, state, action, next_state, agent_done, episode_done, collision, step_reward=-1): - if not self.finished: - if agent_done: - reward = 1 - elif collision: - reward = -5 - else: - reward = step_reward - - # Save experience in replay memory - self.memory.push(state, action, reward, next_state, agent_done or episode_done) - self.finished = episode_done - - # Perform a gradient update every UPDATE_EVERY time steps - # self.t_step = (self.t_step + 1) % UPDATE_EVERY - if len(self.memory) > BATCH_SIZE * 20: - self.learn(*self.memory.sample(BATCH_SIZE, device)) + state = self.memory.stack(state).to(device).transpose(1, 2) + action = self.memory.stack(action).to(device) + reward = self.memory.stack([1 if ad + else (c - 5 if collision else step_reward) + for ad, c in zip(agent_done, collision)]).to(device) + next_state = self.memory.stack(next_state).to(device).transpose(1, 2) + dones = self.memory.stack([[v or episode_done for k, v in a.items() + if not hasattr(k, 'startswith') + or not k.startswith('_')] for a in agent_done]).to(device).float() + state = torch.cat([state, torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) + next_state = torch.cat([next_state, torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) + self.learn(state, action, reward, next_state, dones) def learn(self, states, actions, rewards, next_states, dones): self.qnetwork_local.train() + actions.squeeze_(-1) + dones.squeeze_(-1) + # Get expected Q values from local model Q_expected = self.qnetwork_local(states.squeeze(1)) diff --git a/src/model.py b/src/model.py index 08288cc..3d56331 100644 --- a/src/model.py +++ b/src/model.py @@ -283,8 +283,8 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: def QNetwork(state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=False, debug=True): - model = torch.nn.Sequential(torch.nn.Conv1d(state_size, 11 * hidden_factor, 1, bias=False), - torch.nn.BatchNorm1d(11 * hidden_factor), - Mish(), - torch.nn.Conv1d(11*hidden_factor, action_size, 1)) + model = torch.nn.Sequential(torch.nn.Conv1d(state_size + 1, 11 * hidden_factor, 1, bias=False), + torch.nn.BatchNorm1d(11 * hidden_factor), + Mish(), + torch.nn.Conv1d(11 * hidden_factor, action_size, 1)) return model diff --git a/src/observation_utils.py b/src/observation_utils.py index 61ea7c6..fee22cd 100644 --- a/src/observation_utils.py +++ b/src/observation_utils.py @@ -76,5 +76,4 @@ def normalize_observation(tree, max_depth, zero_center=True): if zero_center: data[:, :6].sub_(data[:, :6].mean()) data[:, 7:].sub_(data[:, 7:].mean()) - return data.flatten() diff --git a/src/railway_utils.py b/src/railway_utils.py index b17c025..a533fa2 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -43,9 +43,9 @@ def __call__(self, *args, **kwargs): # Helper function to load in precomputed railway networks def load_precomputed_railways(project_root, start_index): prefix = os.path.join(project_root, 'railroads') - suffix = f'_sum.pkl' - rail = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) - sched = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) + suffix = f'_3x30x30.pkl' + sched = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) + rail = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) print(f"Working on {len(rail)} tracks") return rail, sched diff --git a/src/replay_memory.py b/src/replay_memory.py index 0fbe503..8a5ddd8 100644 --- a/src/replay_memory.py +++ b/src/replay_memory.py @@ -44,7 +44,9 @@ def sample(self, batch_size, device): actions = self.stack([e.action for e in experiences]).long().to(device) rewards = self.stack([e.reward for e in experiences]).float().to(device) next_states = self.stack([e.next_state for e in experiences]).float().to(device) - dones = self.stack([[v for k, v in e.done.items() if not hasattr(k, 'startswith') or not k.startswith('_')] + dones = self.stack([[v for k, v in e.done.items() + if not hasattr(k, 'startswith') + or not k.startswith('_')] for e in experiences]).float().to(device) return states, actions, rewards, next_states, dones @@ -55,6 +57,8 @@ def stack(self, states, dim=0): return torch.stack([self.stack(st, -1) for st in states], dim) if isinstance(states[0], torch.Tensor): return torch.stack(states, 0) + if isinstance(states[0], Iterable): + return torch.stack([self.stack(st, dim) for st in states], dim) return torch.tensor(states) return torch.tensor(states).view(len(states), 1) diff --git a/src/train.py b/src/train.py index 77b7aee..b098cc4 100644 --- a/src/train.py +++ b/src/train.py @@ -3,12 +3,10 @@ import time from pathlib import Path -import cv2 import numpy as np from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters -from flatland.envs.rail_env import RailEnv from flatland.envs.observations import GlobalObsForRailEnv -from flatland.utils.rendertools import RenderTool, AgentRenderVariant +from flatland.envs.rail_env import RailEnv from tensorboardX import SummaryWriter try: @@ -91,10 +89,9 @@ if flags.global_environment else TreeObservation(max_depth=flags.tree_depth)) ) +environments = [copy.copy(env) for _ in range(BATCH_SIZE)] # After training we want to render the results so we also load a renderer -env_renderer = RenderTool(env, gl="PILSVG", screen_width=800, screen_height=800, - agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX) # Add some variables to keep track of the progress current_score = current_steps = current_collisions = current_done = mean_score = mean_steps = mean_collisions = mean_done = 0 @@ -109,26 +106,20 @@ ACTIONS = {0: 'B', 1: 'L', 2: 'F', 3: 'R', 4: 'S'} -def is_collision(a): - if obs[a] is None: return False - is_junction = not isinstance(obs[a].childs['L'], float) or not isinstance(obs[a].childs['R'], float) +def is_collision(a, i): + if obs[i][a] is None: return False + is_junction = not isinstance(obs[i][a].childs['L'], float) or not isinstance(obs[i][a].childs['R'], float) - if not is_junction or env.agents[a].speed_data['position_fraction'] > 0: - action = ACTIONS[env.agents[a].speed_data['transition_action_on_cellexit']] if is_junction else 'F' - return obs[a].childs[action].num_agents_opposite_direction > 0 \ - and obs[a].childs[action].dist_other_agent_encountered <= 1 \ - and obs[a].childs[action].dist_other_agent_encountered < obs[a].childs[action].dist_unusable_switch + if not is_junction or environments[i].agents[a].speed_data['position_fraction'] > 0: + action = ACTIONS[environments[i].agents[a].speed_data['transition_action_on_cellexit']] if is_junction else 'F' + return obs[i][a].childs[action] != np.inf and obs[i][a].childs[action] != -np.inf\ + and obs[i][a].childs[action].num_agents_opposite_direction > 0 \ + and obs[i][a].childs[action].dist_other_agent_encountered <= 1 \ + and obs[i][a].childs[action].dist_other_agent_encountered < obs[i][a].childs[action].dist_unusable_switch else: return False -# Helper function to render the environment -def render(): - env_renderer.render_env(show_observations=False) - cv2.imshow('Render', cv2.cvtColor(env_renderer.get_image(), cv2.COLOR_BGR2RGB)) - cv2.waitKey(120) - - def get_means(x, y, c, s): return (x * 3 + c) / 4, (y * (s - 1) + c) / s @@ -138,70 +129,74 @@ def get_means(x, y, c, s): # Main training loop for episode in range(start + 1, flags.num_episodes + 1): agent.reset() - env_renderer.reset() obs, info = env.reset(True, True) + environments = [copy.copy(env) for _ in range(BATCH_SIZE)] + obs = [copy.deepcopy(obs) for _ in range(BATCH_SIZE)] + info = [copy.deepcopy(info) for _ in range(BATCH_SIZE)] score, steps_taken, collision = 0, 0, False - agent_obs = [normalize_observation(obs[a], flags.tree_depth, zero_center=True) - for a in obs.keys()] + agent_obs = [[normalize_observation(o[a], flags.tree_depth, zero_center=True) + for a in o.keys()] for o in obs] agent_obs_buffer = copy.deepcopy(agent_obs) - agent_count = len(agent_obs) - agent_action_buffer = [2] * agent_count + agent_count = len(agent_obs[0]) + agent_action_buffer = [[2] * agent_count for _ in range(BATCH_SIZE)] # Run an episode - max_steps = 8 * (env.width + env.height) + max_steps = 8 * env.width + env.height for step in range(max_steps): - update_values = [False] * agent_count - action_dict = {} + update_values = [[False] * agent_count for _ in range(BATCH_SIZE)] + action_dict = [{} for _ in range(BATCH_SIZE)] - if any(info['action_required']): - ret_action = agent.act(agent_obs, eps=eps) + if all(any(inf['action_required']) for inf in info): + ret_action = agent.multi_act(agent_obs, eps=eps) else: ret_action = update_values - for idx, act in enumerate(ret_action): - if info['action_required'][idx]: - action_dict[idx] = act - # action_dict[a] = np.random.randint(5) - update_values[idx] = True - steps_taken += 1 - else: - action_dict[idx] = 0 + for idx, act_list in enumerate(ret_action): + for sub_idx, act in enumerate(act_list): + if info[idx]['action_required'][sub_idx]: + action_dict[idx][sub_idx] = act + # action_dict[a] = np.random.randint(5) + update_values[idx][sub_idx] = True + steps_taken += 1 + else: + action_dict[idx][sub_idx] = 0 # Environment step - obs, rewards, done, info = env.step(action_dict) - score += sum(rewards.values()) / agent_count + obs, rewards, done, info = list(zip(*[e.step(a) for e, a in zip(environments, action_dict)])) + score += sum(sum(r.values()) for r in rewards) / (agent_count * BATCH_SIZE) # Check for collisions and episode completion - if step == max_steps - 1: - done['__all__'] = True - if any(is_collision(a) for a in obs): + all_done = step == (max_steps - 1) + if any(is_collision(a, i) for i, o in enumerate(obs) for a in o): collision = True # done['__all__'] = True # Update replay buffer and train agent - if flags.train and (any(update_values) or any(done) or done['__all__']): + if flags.train and (any(update_values) or all_done or all(any(d) for d in done)): agent.step(agent_obs_buffer, agent_action_buffer, agent_obs, done, - done['__all__'], - [is_collision(a) for a in range(agent_count)], + all_done, + [[is_collision(a, i) for a in range(agent_count)] for i in range(BATCH_SIZE)], flags.step_reward) - agent_obs_buffer = [o.clone() for o in agent_obs] - for key, value in action_dict.items(): - agent_action_buffer[key] = value - - for a in range(agent_count): - if obs[a]: - agent_obs[a] = normalize_observation(obs[a], flags.tree_depth, zero_center=True) - - # Render - # if flags.render_interval and episode % flags.render_interval == 0: - # if collision and all(agent.position for agent in env.agents): - # render() - # print("Collisions detected by agent(s)", ', '.join(str(a) for a in obs if is_collision(a))) - # break - if done['__all__']: + agent_obs_buffer = copy.deepcopy(agent_obs) + for idx, act in enumerate(action_dict): + for key, value in act.items(): + agent_action_buffer[idx][key] = value + + for i in range(BATCH_SIZE): + for a in range(agent_count): + if obs[i][a]: + agent_obs[i][a] = normalize_observation(obs[i][a], flags.tree_depth, zero_center=True) + + # Render + # if flags.render_interval and episode % flags.render_interval == 0: + # if collision and all(agent.position for agent in env.agents): + # render() + # print("Collisions detected by agent(s)", ', '.join(str(a) for a in obs if is_collision(a))) + # break + if all_done: break # Epsilon decay @@ -209,8 +204,6 @@ def get_means(x, y, c, s): eps = max(0.01, flags.epsilon_decay * eps) # Save some training statistics in their respective deques - tasks_finished = sum(done[i] for i in range(agent_count)) - current_done, mean_done = get_means(current_done, mean_done, tasks_finished / max(1, agent_count), episode) current_collisions, mean_collisions = get_means(current_collisions, mean_collisions, int(collision), episode) current_score, mean_score = get_means(current_score, mean_score, score / max_steps, episode) current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken, episode) @@ -219,7 +212,6 @@ def get_means(x, y, c, s): f' | Score: {current_score:.4f}, {mean_score:.4f}' f' | Steps: {current_steps:6.1f}, {mean_steps:6.1f}' f' | Collisions: {100 * current_collisions:5.2f}%, {100 * mean_collisions:5.2f}%' - f' | Finished: {100 * current_done:6.2f}%, {100 * mean_done:6.2f}%' f' | Epsilon: {eps:.2f}' f' | Episode/s: {episode / (time.time() - start_time):.2f}s', end='') From 29fa930d778198b05c7a9695b0e257292ff8d4e6 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Thu, 9 Jul 2020 16:49:42 +0200 Subject: [PATCH 10/75] perf: use torch for of observation_utils.py --- src/observation_utils.py | 63 ++++++++++++++++++++-------------------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/src/observation_utils.py b/src/observation_utils.py index fee22cd..4a7b642 100644 --- a/src/observation_utils.py +++ b/src/observation_utils.py @@ -6,8 +6,7 @@ except: from tree_observation import ACTIONS -ZERO_NODE = np.array([0] * 11) # For Q-Networks -INF_DISTANCE_NODE = np.array([0] * 6 + [np.inf] + [0] * 4) # For policy networks +ZERO_NODE = torch.zeros((11,)) # Helper function to detect collisions @@ -28,7 +27,7 @@ def create_tree_features(node, current_depth, max_depth, empty_node, data): data.extend([empty_node] * num_remaining_nodes) else: - data.append(np.array(tuple(node)[:-2])) + data.append(torch.FloatTensor(tuple(node)[:-2])) if node.childs: for direction in ACTIONS: create_tree_features(node.childs[direction], current_depth + 1, max_depth, empty_node, data) @@ -41,39 +40,41 @@ def create_tree_features(node, current_depth, max_depth, empty_node, data): # Normalize an observation to [0, 1] and then clip it to get rid of any infinite-valued features -#@torch.jit.script -def norm_obs_clip(obs, normalize_to_range): - max_obs = obs[obs < 1000].max() - max_obs.clamp_(min=1) - max_obs.add_(1) +@torch.jit.script +def max_obs(obs): + out = obs[obs < 1000].max() + out.clamp_(min=1) + out.add_(1) + return out - min_obs = torch.zeros(1)[0] - if normalize_to_range.item(): - min_obs.add_(obs[obs >= 0].min().clamp(max=max_obs.item())) +@torch.jit.script +def wrap(data: torch.Tensor): + start = data[:, :6] + mid = data[:, 6] + max0 = max_obs(start) + max1 = max_obs(mid) - if max_obs == min_obs: - obs.div_(max_obs) - else: - obs.sub_(min_obs) - max_obs.sub_(min_obs) - obs.div_(max_obs) - return obs + min_obs = mid[mid >= 0].min() + + mid.sub_(min_obs) + max1.sub_(min_obs) + mid.div_(max1) + + start.div_(max0) + + data.clamp_(-1, 1) + + data[:, :6].sub_(data[:, :6].mean()) + data[:, 7:].sub_(data[:, 7:].mean()) # Normalize a tree observation def normalize_observation(tree, max_depth, zero_center=True): - empty_node = ZERO_NODE if zero_center else INF_DISTANCE_NODE - data = np.concatenate([create_tree_features(t, 0, max_depth, empty_node, []) for t in tree.values()] - if isinstance(tree, dict) else - create_tree_features(tree, 0, max_depth, empty_node, [])).reshape((-1, 11)) - data = torch.as_tensor(data).float() - - norm_obs_clip(data[:, :6], FALSE) - norm_obs_clip(data[:, 6], TRUE) - data.clamp_(-1, 1) + data = torch.cat([create_tree_features(t, 0, max_depth, ZERO_NODE, []) for t in tree.values()] + if isinstance(tree, dict) else + create_tree_features(tree, 0, max_depth, ZERO_NODE, []), 0).view((-1, 11)) + + wrap(data) - if zero_center: - data[:, :6].sub_(data[:, :6].mean()) - data[:, 7:].sub_(data[:, 7:].mean()) - return data.flatten() + return data.view(-1) From 4b40ae898ec121b661576334f6ce349abb976306 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Thu, 9 Jul 2020 17:57:35 +0200 Subject: [PATCH 11/75] perf: remove all numpy --- src/observation_utils.py | 40 ++++++++++++---------------------------- src/train.py | 14 ++++++-------- src/tree_observation.py | 19 +++++++++++-------- 3 files changed, 29 insertions(+), 44 deletions(-) diff --git a/src/observation_utils.py b/src/observation_utils.py index 4a7b642..89296fe 100644 --- a/src/observation_utils.py +++ b/src/observation_utils.py @@ -1,42 +1,23 @@ -import numpy as np import torch try: - from .tree_observation import ACTIONS + from .tree_observation import ACTIONS, negative_infinity, positive_infinity except: - from tree_observation import ACTIONS + from tree_observation import ACTIONS, negative_infinity, positive_infinity ZERO_NODE = torch.zeros((11,)) -# Helper function to detect collisions -def is_collision(obs): - return obs is not None \ - and isinstance(obs.childs['L'], float) \ - and isinstance(obs.childs['R'], float) \ - and obs.childs['F'].num_agents_opposite_direction > 0 \ - and obs.childs['F'].dist_other_agent_encountered <= 1 \ - and obs.childs['F'].dist_other_agent_encountered < obs.childs['F'].dist_unusable_switch - # and obs.childs['F'].dist_other_agent_encountered < obs.childs['F'].dist_to_next_branch - - # Recursively create numpy arrays for each tree node def create_tree_features(node, current_depth, max_depth, empty_node, data): - if node == -np.inf or node is None: - num_remaining_nodes = (4 ** (max_depth - current_depth + 1) - 1) // (4 - 1) + if node == negative_infinity or node == positive_infinity or node is None: + num_remaining_nodes = (4 ** (max_depth - current_depth + 1) - 1) // 3 data.extend([empty_node] * num_remaining_nodes) - else: data.append(torch.FloatTensor(tuple(node)[:-2])) if node.childs: - for direction in ACTIONS: - create_tree_features(node.childs[direction], current_depth + 1, max_depth, empty_node, data) - - return data - - -TRUE = torch.ones(1) -FALSE = torch.zeros(1) + any(create_tree_features(node.childs[direction], current_depth + 1, max_depth, empty_node, data) + for direction in ACTIONS) # Normalize an observation to [0, 1] and then clip it to get rid of any infinite-valued features @@ -71,9 +52,12 @@ def wrap(data: torch.Tensor): # Normalize a tree observation def normalize_observation(tree, max_depth, zero_center=True): - data = torch.cat([create_tree_features(t, 0, max_depth, ZERO_NODE, []) for t in tree.values()] - if isinstance(tree, dict) else - create_tree_features(tree, 0, max_depth, ZERO_NODE, []), 0).view((-1, 11)) + data = [] + if isinstance(tree, dict): + any(create_tree_features(t, 0, max_depth, ZERO_NODE, data) for t in tree.values()) + else: + create_tree_features(tree, 0, max_depth, ZERO_NODE, data) + data = torch.cat(data, 0).view((-1, 11)) wrap(data) diff --git a/src/train.py b/src/train.py index b098cc4..759a50a 100644 --- a/src/train.py +++ b/src/train.py @@ -3,7 +3,6 @@ import time from pathlib import Path -import numpy as np from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv @@ -11,13 +10,13 @@ try: from .agent import Agent as DQN_Agent, device, BATCH_SIZE - from .tree_observation import TreeObservation - from .observation_utils import normalize_observation, is_collision + from .tree_observation import TreeObservation, negative_infinity, positive_infinity + from .observation_utils import normalize_observation from .railway_utils import load_precomputed_railways, create_random_railways except: from agent import Agent as DQN_Agent, device, BATCH_SIZE - from tree_observation import TreeObservation - from observation_utils import normalize_observation, is_collision + from tree_observation import TreeObservation, negative_infinity, positive_infinity + from observation_utils import normalize_observation from railway_utils import load_precomputed_railways, create_random_railways project_root = Path(__file__).resolve().parent.parent @@ -53,13 +52,12 @@ flags = parser.parse_args() # Seeded RNG so we can replicate our results -np.random.seed(1) # Create a tensorboard SummaryWriter summary = SummaryWriter(f'tensorboard/dqn/agents: {flags.num_agents}, tree_depth: {flags.tree_depth}') # Calculate the state size based on the number of nodes in the tree observation num_features_per_node = 11 # env.obs_builder.observation_dim -num_nodes = sum(np.power(4, i) for i in range(flags.tree_depth + 1)) +num_nodes = int('1' * (flags.tree_depth + 1), 4) state_size = num_nodes * num_features_per_node action_size = 5 # Load an RL agent and initialize it from checkpoint if necessary @@ -112,7 +110,7 @@ def is_collision(a, i): if not is_junction or environments[i].agents[a].speed_data['position_fraction'] > 0: action = ACTIONS[environments[i].agents[a].speed_data['transition_action_on_cellexit']] if is_junction else 'F' - return obs[i][a].childs[action] != np.inf and obs[i][a].childs[action] != -np.inf\ + return obs[i][a].childs[action] != negative_infinity and obs[i][a].childs[action] != positive_infinity \ and obs[i][a].childs[action].num_agents_opposite_direction > 0 \ and obs[i][a].childs[action].dist_other_agent_encountered <= 1 \ and obs[i][a].childs[action].dist_other_agent_encountered < obs[i][a].childs[action].dist_unusable_switch diff --git a/src/tree_observation.py b/src/tree_observation.py index d126eb7..1f3e528 100644 --- a/src/tree_observation.py +++ b/src/tree_observation.py @@ -1,6 +1,6 @@ from collections import defaultdict -import numpy as np +import torch from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.agent_utils import RailAgentStatus @@ -9,6 +9,9 @@ ACTIONS = ['L', 'F', 'R', 'B'] Node = TreeObsForRailEnv.Node +positive_infinity = (torch.ones(1) / torch.zeros(1))[0] +negative_infinity = -positive_infinity + def first(list): return next(iter(list)) @@ -159,9 +162,9 @@ def get_many(self, handles=[]): for start, _, start_direction, distance in self.edge_positions[ (*agent.position, direction)]: edge_distance = distance if direction == agent.direction else \ - start.edges[start_direction][1] - distance + start.edges[start_direction][1] - distance edge_dict[(*start.position, start_direction)][agent.handle] = ( - distance, agent.speed_data['speed']) + distance, agent.speed_data['speed']) # Check for malfunctions if agent.malfunction_data['malfunction']: @@ -190,7 +193,7 @@ def get(self, handle): return None # The root node contains information about the agent itself - children = {x: -np.inf for x in ACTIONS} + children = {x: negative_infinity for x in ACTIONS} dist_min_to_target = self.env.distance_map.get()[(handle, *agent_position, agent.direction)] agent_malfunctioning, agent_speed = agent.malfunction_data['malfunction'], agent.speed_data['speed'] root_tree_node = Node(0, 0, 0, 0, 0, 0, dist_min_to_target, 0, 0, agent_malfunctioning, agent_speed, 0, @@ -225,8 +228,8 @@ def get_tree_branch(self, agent, node, direction, visited_cells, total_distance, targets, agents, minor_nodes = [], [], [] edge_length, max_malfunction_length = 0, 0 num_agents_same_direction, num_agents_other_direction = 0, 0 - distance_to_minor_node, distance_to_other_agent = np.inf, np.inf - distance_to_own_target, distance_to_other_target = np.inf, np.inf + distance_to_minor_node, distance_to_other_agent = positive_infinity, positive_infinity + distance_to_own_target, distance_to_other_target = positive_infinity, positive_infinity min_agent_speed, num_agent_departures = 1.0, 0 # Skip ahead until we get to a major node, logging any agents on the tracks along the way @@ -295,7 +298,7 @@ def get_tree_branch(self, agent, node, direction, visited_cells, total_distance, # Create a new tree node and populate its children if depth < self.max_depth: - children = {x: -np.inf for x in ACTIONS} + children = {x: negative_infinity for x in ACTIONS} if not self.is_own_target(agent, node): for direction in node.edges.keys(): children[get_action(orientation, direction)] = \ @@ -308,7 +311,7 @@ def get_tree_branch(self, agent, node, direction, visited_cells, total_distance, return Node(dist_own_target_encountered=total_distance + distance_to_own_target, dist_other_target_encountered=total_distance + distance_to_other_target, dist_other_agent_encountered=total_distance + distance_to_other_agent, - dist_potential_conflict=np.inf, + dist_potential_conflict=positive_infinity, dist_unusable_switch=total_distance + distance_to_minor_node, dist_to_next_branch=total_distance + edge_length, dist_min_to_target=self.env.distance_map.get()[(agent.handle, *node.position, orientation)] or 0, From bc2915141a560a7e522165424c9707f383a0a2b6 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Thu, 9 Jul 2020 18:40:08 +0200 Subject: [PATCH 12/75] perf: use iterative solution instead of recursive one (~+20%) --- src/observation_utils.py | 23 ++++++++++++----------- src/train.py | 6 +++--- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/observation_utils.py b/src/observation_utils.py index 89296fe..d1893cf 100644 --- a/src/observation_utils.py +++ b/src/observation_utils.py @@ -9,15 +9,16 @@ # Recursively create numpy arrays for each tree node -def create_tree_features(node, current_depth, max_depth, empty_node, data): - if node == negative_infinity or node == positive_infinity or node is None: - num_remaining_nodes = (4 ** (max_depth - current_depth + 1) - 1) // 3 - data.extend([empty_node] * num_remaining_nodes) - else: - data.append(torch.FloatTensor(tuple(node)[:-2])) - if node.childs: - any(create_tree_features(node.childs[direction], current_depth + 1, max_depth, empty_node, data) - for direction in ACTIONS) +def create_tree_features(node, max_depth, data): + nodes = [(node, 0)] + for node, current_depth in nodes: + if node == negative_infinity or node == positive_infinity or node is None: + num_remaining_nodes = (4 ** (max_depth - current_depth + 1) - 1) // 3 + data.extend([ZERO_NODE] * num_remaining_nodes) + else: + data.append(torch.FloatTensor(tuple(node)[:-2])) + if node.childs: + nodes.extend((node.childs[direction], current_depth + 1) for direction in ACTIONS) # Normalize an observation to [0, 1] and then clip it to get rid of any infinite-valued features @@ -54,9 +55,9 @@ def wrap(data: torch.Tensor): def normalize_observation(tree, max_depth, zero_center=True): data = [] if isinstance(tree, dict): - any(create_tree_features(t, 0, max_depth, ZERO_NODE, data) for t in tree.values()) + any(create_tree_features(t, max_depth, data) for t in tree.values()) else: - create_tree_features(tree, 0, max_depth, ZERO_NODE, data) + create_tree_features(tree, max_depth, data) data = torch.cat(data, 0).view((-1, 11)) wrap(data) diff --git a/src/train.py b/src/train.py index 759a50a..de48950 100644 --- a/src/train.py +++ b/src/train.py @@ -204,14 +204,14 @@ def get_means(x, y, c, s): # Save some training statistics in their respective deques current_collisions, mean_collisions = get_means(current_collisions, mean_collisions, int(collision), episode) current_score, mean_score = get_means(current_score, mean_score, score / max_steps, episode) - current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken, episode) + current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken / BATCH_SIZE / agent_count, episode) print(f'\rEpisode {episode:<5}' f' | Score: {current_score:.4f}, {mean_score:.4f}' - f' | Steps: {current_steps:6.1f}, {mean_steps:6.1f}' + f' | Agent-Steps: {current_steps:6.1f}, {mean_steps:6.1f}' f' | Collisions: {100 * current_collisions:5.2f}%, {100 * mean_collisions:5.2f}%' f' | Epsilon: {eps:.2f}' - f' | Episode/s: {episode / (time.time() - start_time):.2f}s', end='') + f' | Episode/s: {episode / (time.time() - start_time):.4f}s', end='') if episode % flags.report_interval == 0: print("") From e3e092081711217b4291511102c253da9ecaa75b Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Thu, 9 Jul 2020 21:33:32 +0200 Subject: [PATCH 13/75] perf: add multiprocessing (+30%) --- src/agent.py | 44 +++++++++++++++++++++++----------------- src/model.py | 4 ++-- src/observation_utils.py | 22 ++++++++++++-------- src/train.py | 32 ++++++++++++++--------------- 4 files changed, 56 insertions(+), 46 deletions(-) diff --git a/src/agent.py b/src/agent.py index 517952e..897fdfe 100644 --- a/src/agent.py +++ b/src/agent.py @@ -17,7 +17,7 @@ GAMMA = 0.998 TAU = 1e-3 LR = 2e-4 -UPDATE_EVERY = 1 +UPDATE_EVERY = 16 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -49,6 +49,7 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s # Replay memory self.memory = ReplayBuffer(BATCH_SIZE) + self.stack = [[] for _ in range(5)] self.t_step = 0 def reset(self): @@ -70,11 +71,9 @@ def act(self, state, eps=0.): else torch.randint(self.action_size, ()).item() for i in range(agent_count)] - def multi_act(self, states, eps=0.): - agent_count = len(states[0]) - state = torch.stack([torch.stack(state, -1) if len(state) > 1 else state.unsqueeze(-1) - for state in states], 0).to(device) - state = torch.cat([state, torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) + def multi_act(self, state, eps=0.): + agent_count = state.size(-1) + state = torch.cat([state.to(device), torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) self.qnetwork_local.eval() with torch.no_grad(): action_values = self.qnetwork_local(state) @@ -89,18 +88,25 @@ def multi_act(self, states, eps=0.): # Record the results of the agent's action and update the model def step(self, state, action, next_state, agent_done, episode_done, collision, step_reward=-1): - state = self.memory.stack(state).to(device).transpose(1, 2) - action = self.memory.stack(action).to(device) - reward = self.memory.stack([1 if ad - else (c - 5 if collision else step_reward) - for ad, c in zip(agent_done, collision)]).to(device) - next_state = self.memory.stack(next_state).to(device).transpose(1, 2) - dones = self.memory.stack([[v or episode_done for k, v in a.items() - if not hasattr(k, 'startswith') - or not k.startswith('_')] for a in agent_done]).to(device).float() - state = torch.cat([state, torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) - next_state = torch.cat([next_state, torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) - self.learn(state, action, reward, next_state, dones) + if len(self.stack) >= UPDATE_EVERY - 1: + action = self.memory.stack(self.stack[1]).to(device) + reward = self.memory.stack([1 if ad + else (c - 5 if collision else step_reward) + for ad, c in zip(self.stack[2], self.stack[4])]).to(device) + dones = self.memory.stack(self.stack[3]).to(device).float() + state = state.to(device) + next_state = next_state.to(device) + state = torch.cat([state, torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) + next_state = torch.cat([next_state, torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) + self.learn(state, action, reward, next_state, dones) + else: + self.stack[0].append(state) + self.stack[1].append(action) + self.stack[2].append(next_state) + self.stack[3].append([[v or episode_done for k, v in a.items() + if not hasattr(k, 'startswith') + or not k.startswith('_')] for a in agent_done]) + self.stack[4].append(collision) def learn(self, states, actions, rewards, next_states, dones): self.qnetwork_local.train() @@ -117,7 +123,7 @@ def learn(self, states, actions, rewards, next_states, dones): # Compute loss and perform a gradient step self.optimizer.zero_grad() - loss = (rewards.unsqueeze(-1) + GAMMA * Q_targets_next * (1 - dones.unsqueeze(-2)) - Q_expected).square().mean() + loss = (GAMMA * Q_targets_next * (1 - dones.unsqueeze(-2)) - Q_expected - rewards.unsqueeze(-1)).square().mean() loss.backward() self.optimizer.step() diff --git a/src/model.py b/src/model.py index 3d56331..c8b91bc 100644 --- a/src/model.py +++ b/src/model.py @@ -283,8 +283,8 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: def QNetwork(state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=False, debug=True): - model = torch.nn.Sequential(torch.nn.Conv1d(state_size + 1, 11 * hidden_factor, 1, bias=False), + model = torch.nn.Sequential(WeightDropConv(state_size + 1, 11 * hidden_factor, bias=False), torch.nn.BatchNorm1d(11 * hidden_factor), Mish(), - torch.nn.Conv1d(11 * hidden_factor, action_size, 1)) + WeightDropConv(11 * hidden_factor, action_size)) return model diff --git a/src/observation_utils.py b/src/observation_utils.py index d1893cf..6acc155 100644 --- a/src/observation_utils.py +++ b/src/observation_utils.py @@ -5,7 +5,7 @@ except: from tree_observation import ACTIONS, negative_infinity, positive_infinity -ZERO_NODE = torch.zeros((11,)) +ZERO_NODE = torch.zeros((1, 11)) # Recursively create numpy arrays for each tree node @@ -14,9 +14,9 @@ def create_tree_features(node, max_depth, data): for node, current_depth in nodes: if node == negative_infinity or node == positive_infinity or node is None: num_remaining_nodes = (4 ** (max_depth - current_depth + 1) - 1) // 3 - data.extend([ZERO_NODE] * num_remaining_nodes) + data.append(ZERO_NODE.expand(num_remaining_nodes, -1)) else: - data.append(torch.FloatTensor(tuple(node)[:-2])) + data.append(torch.FloatTensor(node[:-2]).view(1, 11)) if node.childs: nodes.extend((node.childs[direction], current_depth + 1) for direction in ACTIONS) @@ -52,14 +52,18 @@ def wrap(data: torch.Tensor): # Normalize a tree observation -def normalize_observation(tree, max_depth, zero_center=True): +def normalize_observation(tree, max_depth, shared_tensor, inner_index): data = [] if isinstance(tree, dict): - any(create_tree_features(t, max_depth, data) for t in tree.values()) - else: - create_tree_features(tree, max_depth, data) - data = torch.cat(data, 0).view((-1, 11)) + tree = tree.values() + for t in tree: + data.append([]) + if isinstance(t, dict): + any(create_tree_features(d, max_depth, data[-1]) for d in t.values()) + else: + create_tree_features(t, max_depth, data[-1]) + data = torch.stack([torch.cat(dat, 0) for dat in data], -1) wrap(data) - return data.view(-1) + shared_tensor[inner_index] = data.flatten(0, 1) diff --git a/src/train.py b/src/train.py index de48950..94fc423 100644 --- a/src/train.py +++ b/src/train.py @@ -2,7 +2,8 @@ import copy import time from pathlib import Path - +from pathos import multiprocessing +import torch from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv @@ -36,7 +37,7 @@ parser.add_argument("--grid-width", type=int, default=50, help="Number of columns in the environment grid") parser.add_argument("--grid-height", type=int, default=50, help="Number of rows in the environment grid") parser.add_argument("--num-agents", type=int, default=5, help="Number of agents in each episode") -parser.add_argument("--tree-depth", type=int, default=2, help="Depth of the observation tree") +parser.add_argument("--tree-depth", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--model-depth", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--hidden-factor", type=int, default=5, help="Depth of the observation tree") parser.add_argument("--kernel-size", type=int, default=1, help="Depth of the observation tree") @@ -46,7 +47,7 @@ parser.add_argument("--agent-type", default="dqn", choices=["dqn", "ppo"], help="Which type of RL agent to use") parser.add_argument("--num-episodes", type=int, default=10 ** 6, help="Number of episodes to train for") parser.add_argument("--epsilon-decay", type=float, default=0, help="Decay factor for epsilon-greedy exploration") -parser.add_argument("--step-reward", type=float, default=-1e-2, help="Depth of the observation tree") +parser.add_argument("--step-reward", type=float, default=-1, help="Depth of the observation tree") parser.add_argument("--global-environment", type=boolean, default=False, help="Depth of the observation tree") flags = parser.parse_args() @@ -92,7 +93,7 @@ # After training we want to render the results so we also load a renderer # Add some variables to keep track of the progress -current_score = current_steps = current_collisions = current_done = mean_score = mean_steps = mean_collisions = mean_done = 0 +current_score = current_steps = current_collisions = current_done = mean_score = mean_steps = mean_collisions = mean_done = current_taken = mean_taken = 0 agent_action_buffer = [] start_time = time.time() @@ -123,6 +124,7 @@ def get_means(x, y, c, s): episode = 0 +POOL = multiprocessing.Pool() # Main training loop for episode in range(start + 1, flags.num_episodes + 1): @@ -132,11 +134,10 @@ def get_means(x, y, c, s): obs = [copy.deepcopy(obs) for _ in range(BATCH_SIZE)] info = [copy.deepcopy(info) for _ in range(BATCH_SIZE)] score, steps_taken, collision = 0, 0, False - - agent_obs = [[normalize_observation(o[a], flags.tree_depth, zero_center=True) - for a in o.keys()] for o in obs] - agent_obs_buffer = copy.deepcopy(agent_obs) - agent_count = len(agent_obs[0]) + agent_count = len(obs[0]) + agent_obs = torch.zeros((BATCH_SIZE, state_size, agent_count)) + POOL.starmap(func=normalize_observation, iterable=((o, flags.tree_depth, agent_obs, i) for i, o in enumerate(obs))) + agent_obs_buffer = agent_obs.clone() agent_action_buffer = [[2] * agent_count for _ in range(BATCH_SIZE)] # Run an episode @@ -164,7 +165,7 @@ def get_means(x, y, c, s): score += sum(sum(r.values()) for r in rewards) / (agent_count * BATCH_SIZE) # Check for collisions and episode completion - all_done = step == (max_steps - 1) + all_done = step == (max_steps - 1) or any(d['__all__'] for d in done) if any(is_collision(a, i) for i, o in enumerate(obs) for a in o): collision = True # done['__all__'] = True @@ -178,16 +179,13 @@ def get_means(x, y, c, s): all_done, [[is_collision(a, i) for a in range(agent_count)] for i in range(BATCH_SIZE)], flags.step_reward) - agent_obs_buffer = copy.deepcopy(agent_obs) + agent_obs_buffer = agent_obs.clone() for idx, act in enumerate(action_dict): for key, value in act.items(): agent_action_buffer[idx][key] = value - for i in range(BATCH_SIZE): - for a in range(agent_count): - if obs[i][a]: - agent_obs[i][a] = normalize_observation(obs[i][a], flags.tree_depth, zero_center=True) - + POOL.starmap(func=normalize_observation, + iterable=((o, flags.tree_depth, agent_obs, i) for i, o in enumerate(obs))) # Render # if flags.render_interval and episode % flags.render_interval == 0: # if collision and all(agent.position for agent in env.agents): @@ -205,10 +203,12 @@ def get_means(x, y, c, s): current_collisions, mean_collisions = get_means(current_collisions, mean_collisions, int(collision), episode) current_score, mean_score = get_means(current_score, mean_score, score / max_steps, episode) current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken / BATCH_SIZE / agent_count, episode) + current_taken, mean_taken = get_means(current_steps, mean_steps, step, episode) print(f'\rEpisode {episode:<5}' f' | Score: {current_score:.4f}, {mean_score:.4f}' f' | Agent-Steps: {current_steps:6.1f}, {mean_steps:6.1f}' + f' | Steps Taken: {current_taken:6.1f}, {mean_taken:6.1f}' f' | Collisions: {100 * current_collisions:5.2f}%, {100 * mean_collisions:5.2f}%' f' | Epsilon: {eps:.2f}' f' | Episode/s: {episode / (time.time() - start_time):.4f}s', end='') From c966052b51ff7de103759fe7a365ef14c1d55efe Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 10 Jul 2020 12:21:56 +0200 Subject: [PATCH 14/75] perf: init removal of redundant vars --- src/tree_observation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tree_observation.py b/src/tree_observation.py index 1f3e528..a17524b 100644 --- a/src/tree_observation.py +++ b/src/tree_observation.py @@ -70,11 +70,11 @@ def reset(self): # Now we create a graph representation of the rail network, starting from this node transitions = self.get_all_transitions(position) - root_nodes = {t: RailNode(position, t, self.is_target(position)) for t in transitions if t} - self.graph = {(*position, d): root_nodes[t] for d, t in enumerate(transitions) if t} + root_nodes = [RailNode(position, t, self.is_target(position)) for t in transitions if t] + self.graph = {(*position, d): n for d, n in enumerate(root_nodes)} - for transitions, node in root_nodes.items(): - for direction in transitions: + for node in root_nodes: + for direction in node.edge_directions: self.explore_branch(node, get_new_position(position, direction), direction) def explore_branch(self, node, position, direction): From 3bf53d43d94b0ec11b095900e304307b68eb4b67 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 10 Jul 2020 12:22:40 +0200 Subject: [PATCH 15/75] Revert "perf: init removal of redundant vars" This reverts commit c966052b --- src/tree_observation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tree_observation.py b/src/tree_observation.py index a17524b..1f3e528 100644 --- a/src/tree_observation.py +++ b/src/tree_observation.py @@ -70,11 +70,11 @@ def reset(self): # Now we create a graph representation of the rail network, starting from this node transitions = self.get_all_transitions(position) - root_nodes = [RailNode(position, t, self.is_target(position)) for t in transitions if t] - self.graph = {(*position, d): n for d, n in enumerate(root_nodes)} + root_nodes = {t: RailNode(position, t, self.is_target(position)) for t in transitions if t} + self.graph = {(*position, d): root_nodes[t] for d, t in enumerate(transitions) if t} - for node in root_nodes: - for direction in node.edge_directions: + for transitions, node in root_nodes.items(): + for direction in transitions: self.explore_branch(node, get_new_position(position, direction), direction) def explore_branch(self, node, position, direction): From 3425712140328de6f65af5db81445b17dcf64c7c Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 10 Jul 2020 12:23:12 +0200 Subject: [PATCH 16/75] perf: cythonize normalizer --- src/cythonize.sh | 6 ++++ src/normalize_output_data.py | 33 +++++++++++++++++ src/observation_utils.py | 69 ------------------------------------ src/observation_utils.pyx | 44 +++++++++++++++++++++++ src/train.py | 57 +++++++++++++++++++---------- 5 files changed, 122 insertions(+), 87 deletions(-) create mode 100644 src/cythonize.sh create mode 100644 src/normalize_output_data.py delete mode 100644 src/observation_utils.py create mode 100644 src/observation_utils.pyx diff --git a/src/cythonize.sh b/src/cythonize.sh new file mode 100644 index 0000000..e08d63d --- /dev/null +++ b/src/cythonize.sh @@ -0,0 +1,6 @@ +cython observation_utils.pyx -3 -Wextra -D +cmd="gcc-7 observation_utils.c `python3-config --cflags --ldflags --includes --libs` -fno-lto -pthread -fPIC -fwrapv -pipe -march=native -mtune=native -Ofast -msse2 -msse4.2 -shared -o observation_utils.so" +echo "Executing $cmd" +$cmd +echo "Testing compilation.." +python3 -c "import observation_utils" diff --git a/src/normalize_output_data.py b/src/normalize_output_data.py new file mode 100644 index 0000000..b08b97b --- /dev/null +++ b/src/normalize_output_data.py @@ -0,0 +1,33 @@ +import torch + +#torch.jit.optimized_execution(True) + + +def wrap(data: torch.Tensor): + start = data[:, :, :6] + mid = data[:, :, 6] + + max0 = torch.where(start < 1000, start, torch.zeros_like(start)) + max0 = max0.max(dim=1, keepdim=True)[0] + max0.clamp_(min=1) + + max1 = torch.where(mid < 1000, mid, torch.zeros_like(mid)) + max1 = max1.max(dim=1, keepdim=True)[0] + max1.clamp_(min=1) + + min_mid = torch.where(mid >= 0, mid, torch.zeros_like(mid)) + min_obs = min_mid.min(dim=1, keepdim=True)[0] + + mid.sub_(min_obs) + max1.sub_(min_obs) + mid.div_(max1) + + start.div_(max0) + + data.clamp_(-1, 1) + + data[:, :, :6].sub_(data[:, :, :6].mean()) + data[:, :, 7:].sub_(data[:, :, 7:].mean()) + data.detach_() + +wrap = torch.jit.script(wrap) diff --git a/src/observation_utils.py b/src/observation_utils.py deleted file mode 100644 index 6acc155..0000000 --- a/src/observation_utils.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch - -try: - from .tree_observation import ACTIONS, negative_infinity, positive_infinity -except: - from tree_observation import ACTIONS, negative_infinity, positive_infinity - -ZERO_NODE = torch.zeros((1, 11)) - - -# Recursively create numpy arrays for each tree node -def create_tree_features(node, max_depth, data): - nodes = [(node, 0)] - for node, current_depth in nodes: - if node == negative_infinity or node == positive_infinity or node is None: - num_remaining_nodes = (4 ** (max_depth - current_depth + 1) - 1) // 3 - data.append(ZERO_NODE.expand(num_remaining_nodes, -1)) - else: - data.append(torch.FloatTensor(node[:-2]).view(1, 11)) - if node.childs: - nodes.extend((node.childs[direction], current_depth + 1) for direction in ACTIONS) - - -# Normalize an observation to [0, 1] and then clip it to get rid of any infinite-valued features -@torch.jit.script -def max_obs(obs): - out = obs[obs < 1000].max() - out.clamp_(min=1) - out.add_(1) - return out - - -@torch.jit.script -def wrap(data: torch.Tensor): - start = data[:, :6] - mid = data[:, 6] - max0 = max_obs(start) - max1 = max_obs(mid) - - min_obs = mid[mid >= 0].min() - - mid.sub_(min_obs) - max1.sub_(min_obs) - mid.div_(max1) - - start.div_(max0) - - data.clamp_(-1, 1) - - data[:, :6].sub_(data[:, :6].mean()) - data[:, 7:].sub_(data[:, 7:].mean()) - - -# Normalize a tree observation -def normalize_observation(tree, max_depth, shared_tensor, inner_index): - data = [] - if isinstance(tree, dict): - tree = tree.values() - for t in tree: - data.append([]) - if isinstance(t, dict): - any(create_tree_features(d, max_depth, data[-1]) for d in t.values()) - else: - create_tree_features(t, max_depth, data[-1]) - data = torch.stack([torch.cat(dat, 0) for dat in data], -1) - - wrap(data) - - shared_tensor[inner_index] = data.flatten(0, 1) diff --git a/src/observation_utils.pyx b/src/observation_utils.pyx new file mode 100644 index 0000000..85d2fce --- /dev/null +++ b/src/observation_utils.pyx @@ -0,0 +1,44 @@ +import torch + +try: + from .tree_observation import ACTIONS, negative_infinity, positive_infinity +except: + from tree_observation import ACTIONS, negative_infinity, positive_infinity + + +ZERO_NODE = torch.zeros((1, 11)) + +# Recursively create numpy arrays for each tree node +cpdef create_tree_features(node, int max_depth, list data): + cdef list nodes = [(node, 0)] + cdef int current_depth = 0 + for node, current_depth in nodes: + if node == negative_infinity or node == positive_infinity or node is None: + data.append(ZERO_NODE.expand((4 ** (max_depth - current_depth + 1) - 1) // 3, -1)) + else: + data.append(torch.FloatTensor(node[:-2]).view(1, 11)) + if node.childs: + for direction in ACTIONS: + nodes.append((node.childs[direction], current_depth + 1)) + +# Normalize a tree observation +cpdef normalize_observation(tuple observations, int max_depth, shared_tensor, int starting_index): + cdef list data = [] + cdef int i = 0 + for i, tree in enumerate(observations, 1): + if tree is None: + break + data.append([]) + if isinstance(tree, dict): + tree = tree.values() + for t in tree: + data[-1].append([]) + if isinstance(t, dict): + for d in t.values(): + create_tree_features(d, max_depth, data[-1][-1]) + else: + create_tree_features(t, max_depth, data[-1][-1]) + + shared_tensor[starting_index:starting_index + i] = torch.stack([torch.stack([torch.cat(dat, 0) + for dat in tree if dat != []], -1) + for tree in data if tree != []], 0) \ No newline at end of file diff --git a/src/train.py b/src/train.py index 94fc423..a9af96d 100644 --- a/src/train.py +++ b/src/train.py @@ -1,21 +1,24 @@ import argparse import copy import time +from itertools import zip_longest from pathlib import Path -from pathos import multiprocessing + import torch from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv -from tensorboardX import SummaryWriter +from pathos import multiprocessing try: from .agent import Agent as DQN_Agent, device, BATCH_SIZE + from .normalize_output_data import wrap from .tree_observation import TreeObservation, negative_infinity, positive_infinity from .observation_utils import normalize_observation from .railway_utils import load_precomputed_railways, create_random_railways except: from agent import Agent as DQN_Agent, device, BATCH_SIZE + from normalize_output_data import wrap from tree_observation import TreeObservation, negative_infinity, positive_infinity from observation_utils import normalize_observation from railway_utils import load_precomputed_railways, create_random_railways @@ -49,13 +52,13 @@ parser.add_argument("--epsilon-decay", type=float, default=0, help="Decay factor for epsilon-greedy exploration") parser.add_argument("--step-reward", type=float, default=-1, help="Depth of the observation tree") parser.add_argument("--global-environment", type=boolean, default=False, help="Depth of the observation tree") +parser.add_argument("--threads", type=int, default=1, help="Depth of the observation tree") flags = parser.parse_args() # Seeded RNG so we can replicate our results # Create a tensorboard SummaryWriter -summary = SummaryWriter(f'tensorboard/dqn/agents: {flags.num_agents}, tree_depth: {flags.tree_depth}') # Calculate the state size based on the number of nodes in the tree observation num_features_per_node = 11 # env.obs_builder.observation_dim num_nodes = int('1' * (flags.tree_depth + 1), 4) @@ -123,6 +126,24 @@ def get_means(x, y, c, s): return (x * 3 + c) / 4, (y * (s - 1) + c) / s +chunk_size = (BATCH_SIZE + 1) // flags.threads + + +def chunk(obj, size): + return zip_longest(*[iter(obj)] * size, fillvalue=None) + + +if flags.threads > 1: + def normalize(observation, target_tensor): + POOL.starmap(func=normalize_observation, + iterable=((o, flags.tree_depth, target_tensor, i * chunk_size) + for i, o in enumerate(chunk(observation, chunk_size)))) + wrap(target_tensor) +else: + def normalize(observation, target_tensor): + normalize_observation(observation, flags.tree_depth, target_tensor, 0) + wrap(target_tensor) + episode = 0 POOL = multiprocessing.Pool() @@ -131,12 +152,12 @@ def get_means(x, y, c, s): agent.reset() obs, info = env.reset(True, True) environments = [copy.copy(env) for _ in range(BATCH_SIZE)] - obs = [copy.deepcopy(obs) for _ in range(BATCH_SIZE)] + obs = tuple(copy.deepcopy(obs) for _ in range(BATCH_SIZE)) info = [copy.deepcopy(info) for _ in range(BATCH_SIZE)] score, steps_taken, collision = 0, 0, False agent_count = len(obs[0]) - agent_obs = torch.zeros((BATCH_SIZE, state_size, agent_count)) - POOL.starmap(func=normalize_observation, iterable=((o, flags.tree_depth, agent_obs, i) for i, o in enumerate(obs))) + agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) + normalize(obs, agent_obs) agent_obs_buffer = agent_obs.clone() agent_action_buffer = [[2] * agent_count for _ in range(BATCH_SIZE)] @@ -147,7 +168,7 @@ def get_means(x, y, c, s): action_dict = [{} for _ in range(BATCH_SIZE)] if all(any(inf['action_required']) for inf in info): - ret_action = agent.multi_act(agent_obs, eps=eps) + ret_action = agent.multi_act(agent_obs.flatten(1, 2), eps=eps) else: ret_action = update_values for idx, act_list in enumerate(ret_action): @@ -161,20 +182,20 @@ def get_means(x, y, c, s): action_dict[idx][sub_idx] = 0 # Environment step - obs, rewards, done, info = list(zip(*[e.step(a) for e, a in zip(environments, action_dict)])) + obs, rewards, done, info = tuple(zip(*[e.step(a) for e, a in zip(environments, action_dict)])) score += sum(sum(r.values()) for r in rewards) / (agent_count * BATCH_SIZE) # Check for collisions and episode completion - all_done = step == (max_steps - 1) or any(d['__all__'] for d in done) + all_done = (step == (max_steps - 1)) or any(d['__all__'] for d in done) if any(is_collision(a, i) for i, o in enumerate(obs) for a in o): collision = True # done['__all__'] = True # Update replay buffer and train agent if flags.train and (any(update_values) or all_done or all(any(d) for d in done)): - agent.step(agent_obs_buffer, + agent.step(agent_obs_buffer.flatten(1, 2), agent_action_buffer, - agent_obs, + agent_obs.flatten(1, 2), done, all_done, [[is_collision(a, i) for a in range(agent_count)] for i in range(BATCH_SIZE)], @@ -184,34 +205,34 @@ def get_means(x, y, c, s): for key, value in act.items(): agent_action_buffer[idx][key] = value - POOL.starmap(func=normalize_observation, - iterable=((o, flags.tree_depth, agent_obs, i) for i, o in enumerate(obs))) + if all_done: + break + + normalize(obs, agent_obs) + # Render # if flags.render_interval and episode % flags.render_interval == 0: # if collision and all(agent.position for agent in env.agents): # render() # print("Collisions detected by agent(s)", ', '.join(str(a) for a in obs if is_collision(a))) # break - if all_done: - break # Epsilon decay if flags.train: eps = max(0.01, flags.epsilon_decay * eps) - # Save some training statistics in their respective deques current_collisions, mean_collisions = get_means(current_collisions, mean_collisions, int(collision), episode) current_score, mean_score = get_means(current_score, mean_score, score / max_steps, episode) current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken / BATCH_SIZE / agent_count, episode) current_taken, mean_taken = get_means(current_steps, mean_steps, step, episode) - print(f'\rEpisode {episode:<5}' + print(f'\rBatch {episode:<5} - Episode {BATCH_SIZE*episode:<5}' f' | Score: {current_score:.4f}, {mean_score:.4f}' f' | Agent-Steps: {current_steps:6.1f}, {mean_steps:6.1f}' f' | Steps Taken: {current_taken:6.1f}, {mean_taken:6.1f}' f' | Collisions: {100 * current_collisions:5.2f}%, {100 * mean_collisions:5.2f}%' f' | Epsilon: {eps:.2f}' - f' | Episode/s: {episode / (time.time() - start_time):.4f}s', end='') + f' | Episode/s: {BATCH_SIZE * episode / (time.time() - start_time):.4f}s', end='') if episode % flags.report_interval == 0: print("") From 6a85249521252f65e473f081e54dd8e0639c66ee Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 10 Jul 2020 13:45:28 +0200 Subject: [PATCH 17/75] perf: cythonize tree observation --- src/normalize_output_data.py | 4 +- src/observation_utils.pyx | 371 ++++++++++++++++++++++++++++++++++- src/train.py | 11 +- src/tree_observation.py | 341 -------------------------------- 4 files changed, 375 insertions(+), 352 deletions(-) diff --git a/src/normalize_output_data.py b/src/normalize_output_data.py index b08b97b..aa1de4a 100644 --- a/src/normalize_output_data.py +++ b/src/normalize_output_data.py @@ -1,6 +1,7 @@ import torch -#torch.jit.optimized_execution(True) + +torch.jit.optimized_execution(True) def wrap(data: torch.Tensor): @@ -30,4 +31,5 @@ def wrap(data: torch.Tensor): data[:, :, 7:].sub_(data[:, :, 7:].mean()) data.detach_() + wrap = torch.jit.script(wrap) diff --git a/src/observation_utils.pyx b/src/observation_utils.pyx index 85d2fce..333a238 100644 --- a/src/observation_utils.pyx +++ b/src/observation_utils.pyx @@ -1,9 +1,370 @@ +from collections import defaultdict + import torch +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.observations import TreeObsForRailEnv + +cdef list ACTIONS = ['L', 'F', 'R', 'B'] +Node = TreeObsForRailEnv.Node + +cdef int positive_infinity = int(1e5) +cdef int negative_infinity = -positive_infinity + +def first(iterable): + for elem in iterable: + return elem + +cpdef bint _check_len1(tuple obj): + return len(obj) > 1 + +cpdef str get_action(int orientation, int direction): + return ACTIONS[(direction - orientation + 1 + 4) % 4] + +cpdef int get_direction(int orientation, int action): + if action == 1: + return (orientation + 4 - 1) % 4 + elif action == 3: + return (orientation + 1) % 4 + else: + return orientation + +cdef class RailNode: + cdef public dict edges + cdef public tuple position + cdef public tuple edge_directions + cdef public int is_target + def __init__(self, tuple position, tuple edge_directions, int is_target): + self.edges = {} + self.position = position + self.edge_directions = edge_directions + self.is_target = is_target + + def __repr__(self): + return f'RailNode({self.position}, {len(self.edges)})' + + +class TreeObservation(ObservationBuilder): + def __init__(self, max_depth): + super().__init__() + self.max_depth = max_depth + self.observation_dim = 11 + + # Create a graph representation of the current rail network + + def reset(self): + self.target_positions = {agent.target: 1 for agent in self.env.agents} + self.edge_positions = defaultdict(list) # (cell.position, direction) -> [(start, end, direction, distance)] + self.edge_paths = defaultdict(list) # (node.position, direction) -> [(cell.position, direction)] + + # First, we find a node by starting at one of the agents and following the rails until we reach a junction + agent = first(self.env.agents) + cpdef tuple position = tuple(agent.initial_position) + cpdef int direction = agent.direction + cdef bint out + while True: + try: + out = self.is_junction(position) or self.is_target(position) + except IndexError: + break + if not out: + break + direction = first(self.get_possible_transitions(position, direction)) + position = get_new_position(position, direction) + + # Now we create a graph representation of the rail network, starting from this node + cdef tuple transitions = self.get_all_transitions(position) + cdef dict root_nodes = {t: RailNode(position, t, self.is_target(position)) for t in transitions if t} + self.graph = {(*position, d): root_nodes[t] for d, t in enumerate(transitions) if t} + + for transitions, node in root_nodes.items(): + for direction in transitions: + self.explore_branch(node, get_new_position(position, direction), direction) + + def explore_branch(self, RailNode node, tuple position, int direction): + cdef int original_direction = direction + cdef dict edge_positions = {} + cdef int distance = 1 + cdef int next_direction = 0 + cdef int idx = 0 + cdef tuple transition = tuple() + cdef tuple key = tuple() + + # Explore until we find a junction + while not self.is_junction(position) and not self.is_target(position): + next_direction = first(self.get_possible_transitions(position, direction)) + edge_positions[(*position, direction)] = (distance, next_direction) + position = get_new_position(position, next_direction) + direction = next_direction + distance += 1 + + # Create any nodes that aren't in the graph yet + cdef tuple transitions = self.get_all_transitions(position) + cdef bint is_target = self.is_target(position) + cdef dict nodes = {transition: RailNode(position, transition, is_target) + for idx, transition in enumerate(transitions) + if transition and (*position, idx) not in self.graph} + + for idx, transition in enumerate(transitions): + if transition in nodes: + self.graph[(*position, idx)] = nodes[transition] + + # Connect the previous node to the next one, and update self.edge_positions + cdef RailNode next_node = self.graph[(*position, direction)] + node.edges[original_direction] = (next_node, distance) + for key, (distance, next_direction) in edge_positions.items(): + self.edge_positions[key].append((node, next_node, original_direction, distance)) + self.edge_paths[node.position, original_direction].append((*key, next_direction)) + + # Call ourselves recursively since we're exploring depth-first + for transitions, node in nodes.items(): + for direction in transitions: + self.explore_branch(node, get_new_position(position, direction), direction) + + # Create a tree observation for each agent, based on the graph we created earlier + + def get_many(self, list handles=[]): + self.nodes_with_agents_going, self.edges_with_agents_going = {}, defaultdict(dict) + self.nodes_with_agents_coming, self.edges_with_agents_coming = {}, defaultdict(dict) + self.nodes_with_malfunctions, self.edges_with_malfunctions = {}, defaultdict(dict) + self.nodes_with_departures, self.edges_with_departures = {}, defaultdict(dict) + + cdef int direction = 0 + + # Create some lookup tables that we can use later to figure out how far away the agents are from each other. + for agent in self.env.agents: + if agent.status == RailAgentStatus.READY_TO_DEPART and agent.initial_position: + for direction in range(4): + if (*agent.initial_position, direction) in self.graph: + self.nodes_with_departures[(*agent.initial_position, direction)] = 1 + + for start, _, start_direction, distance in self.edge_positions[ + (*agent.initial_position, direction)]: + self.edges_with_departures[(*start.position, start_direction)][agent.handle] = distance + + if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and agent.position: + agent_key = (*agent.position, agent.direction) + for direction in range(4): + # # Check the nodes + if (*agent.position, direction) in self.graph: + node_dict = self.nodes_with_agents_going if direction == agent.direction else self.nodes_with_agents_coming + node_dict[(*agent.position, direction)] = agent.speed_data['speed'] + + # if len(self.graph[agent_key].edges) > 1: + # exit_direction = get_direction(agent.direction, agent.speed_data['transition_action_on_cellexit']) + # if agent.speed_data['position_fraction'] == 0 or exit_direction not in self.graph[agent_key].edges: # Agent still has options + # self.nodes_with_agents_going[(*agent.position, direction)] = agent.speed_data['speed'] + # else: # Agent has already decided + # coming_direction = (exit_direction + 2) % 4 + # node_dict = self.nodes_with_agents_coming if direction == coming_direction else self.nodes_with_agents_going + # node_dict[(*agent.position, direction)] = agent.speed_data['speed'] + # else: + # exit_direction = first(self.graph[agent_key].edges.keys()) + # coming_direction = (exit_direction + 2) % 4 + # node_dict = self.nodes_with_agents_coming if direction == coming_direction else self.nodes_with_agents_going + # node_dict[(*agent.position, direction)] = agent.speed_data['speed'] + + # Check the edges + if agent_key in self.edge_positions: + exit_direction = first(self.get_possible_transitions(agent.position, agent.direction)) + coming_direction = (exit_direction + 2) % 4 + edge_dict = self.edges_with_agents_coming if direction == coming_direction else self.edges_with_agents_going + if direction == agent.direction or direction == coming_direction: + for start, _, start_direction, distance in self.edge_positions[ + (*agent.position, direction)]: + edge_dict[(*start.position, start_direction)][agent.handle] = ( + distance, agent.speed_data['speed']) + + # Check for malfunctions + if agent.malfunction_data['malfunction']: + if (*agent.position, direction) in self.graph: + self.nodes_with_malfunctions[(*agent.position, direction)] = agent.malfunction_data[ + 'malfunction'] + + for start, _, start_direction, distance in self.edge_positions[(*agent.position, direction)]: + self.edges_with_malfunctions[(*start.position, start_direction)][agent.handle] = \ + (distance, agent.malfunction_data['malfunction']) + + return super().get_many(handles) + + # Compute the observation for a single agent + def get(self, int handle): + agent = self.env.agents[handle] + cdef set visited_cells = set() + + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_position = agent.target + else: + return None + + # The root node contains information about the agent itself + cdef int direction = 0 + cdef int distance = 0 + root_tree_node = Node(0, 0, 0, 0, 0, 0, self.env.distance_map.get()[(handle, *agent_position, agent.direction)], + 0, 0, agent.malfunction_data['malfunction'], agent.speed_data['speed'], 0, + {x: negative_infinity for x in ACTIONS}) + + # Then we build out the tree by exploring from this node + cdef tuple key = (*agent_position, agent.direction) + cdef RailNode node = RailNode(tuple(), tuple(), 0) + cdef RailNode prev_node = RailNode(tuple(), tuple(), 0) + if key in self.graph: # If we're sitting on a junction, branch out immediately + node = self.graph[key] + if len(node.edges) > 1: # Major node + for direction in self.graph[key].edges.keys(): + root_tree_node.childs[get_action(agent.direction, direction)] = \ + self.get_tree_branch(agent, node, direction, visited_cells, 0, 1) + else: # Minor node + direction = first(self.get_possible_transitions(node.position, agent.direction)) + root_tree_node.childs['F'] = self.get_tree_branch(agent, node, direction, visited_cells, 0, 1) + + else: # Just create a single child in the forward direction + prev_node, _, direction, distance = first(self.edge_positions[key]) + root_tree_node.childs['F'] = self.get_tree_branch(agent, prev_node, direction, visited_cells, -distance, 1) + + self.env.dev_obs_dict[handle] = visited_cells + + return root_tree_node + + # Get the next tree node, starting from `node`, facing `orientation`, and moving in `direction`. + def get_tree_branch(self, agent, RailNode node, int direction, visited_cells, int total_distance, int depth): + visited_cells.add((*node.position, 0)) + next_node, distance = node.edges[direction] + + cdef int edge_length = 0 + cdef int max_malfunction_length = 0 + cdef int num_agents_same_direction = 0 + cdef int num_agents_other_direction = 0 + cdef int distance_to_minor_node = positive_infinity + cdef int distance_to_other_agent = positive_infinity + cdef int distance_to_own_target = positive_infinity + cdef int distance_to_other_target = positive_infinity + cdef float min_agent_speed = 1 + cdef int num_agent_departures = 0 + + cdef int orientation = 0 + cdef int dist = 0 + + cdef int tmp_dist = 0 + cdef int tmp1 = 0 # Speed/Malfunction Length + + cdef tuple key = tuple() + cdef tuple next_key = tuple() + + cdef list path = list() + + # Skip ahead until we get to a major node, logging any agents on the tracks along the way + while True: + path = self.edge_paths.get((node.position, direction), []) + orientation = path[-1][-1] if path else direction + dist = total_distance + edge_length + key = (*node.position, direction) + next_key = (*next_node.position, orientation) + + visited_cells.update(path) + visited_cells.add((*next_node.position, 0)) + + # Check for other agents on the junctions up ahead + if next_key in self.nodes_with_agents_going: + num_agents_same_direction += 1 + # distance_to_other_agent = min(distance_to_other_agent, edge_length + distance) + min_agent_speed = min(min_agent_speed, self.nodes_with_agents_going[next_key]) + + if next_key in self.nodes_with_agents_coming: + num_agents_other_direction += 1 + distance_to_other_agent = min(distance_to_other_agent, edge_length + distance) + + if next_key in self.nodes_with_departures: + num_agent_departures += 1 + if next_key in self.nodes_with_malfunctions: + max_malfunction_length = max(max_malfunction_length, self.nodes_with_malfunctions[next_key]) + + # Check for other agents along the tracks up ahead + for tmp_dist, tmp1 in self.edges_with_agents_going[key].values(): + if dist + tmp_dist > 0: + num_agents_same_direction += 1 + min_agent_speed = min(min_agent_speed, tmp1) + # distance_to_other_agent = min(distance_to_other_agent, edge_length + d) + + for tmp_dist, _ in self.edges_with_agents_coming[key].values(): + if dist + tmp_dist > 0: + num_agents_other_direction += 1 + distance_to_other_agent = min(distance_to_other_agent, edge_length + tmp_dist) + + for tmp_dist in self.edges_with_departures[key].values(): + if dist + tmp_dist > 0: + num_agent_departures += 1 + + for tmp_dist, tmp1 in self.edges_with_malfunctions[key].values(): + if dist + tmp_dist > 0: + max_malfunction_length = max(max_malfunction_length, tmp1) + + # Check for target nodes up ahead + if next_node.is_target: + if self.is_own_target(agent, next_node): + distance_to_own_target = min(distance_to_own_target, edge_length + distance) + else: + distance_to_other_target = min(distance_to_other_target, edge_length + distance) + + # Move on to the next node + node = next_node + edge_length += distance + + if len(node.edges) == 1 and not self.is_own_target(agent, node): # This is a minor node, keep exploring + direction, (next_node, distance) = first(node.edges.items()) + if not node.is_target: + distance_to_minor_node = min(distance_to_minor_node, edge_length) + else: + break + + # Create a new tree node and populate its children + cdef dict children = {} + cdef str x = '' + if depth < self.max_depth: + for x in ACTIONS: + children[x] = negative_infinity + if not self.is_own_target(agent, node): + for direction in node.edges.keys(): + children[get_action(orientation, direction)] = \ + self.get_tree_branch(agent, node, direction, visited_cells, total_distance + edge_length, + depth + 1) + + return Node(dist_own_target_encountered=total_distance + distance_to_own_target, + dist_other_target_encountered=total_distance + distance_to_other_target, + dist_other_agent_encountered=total_distance + distance_to_other_agent, + dist_potential_conflict=positive_infinity, + dist_unusable_switch=total_distance + distance_to_minor_node, + dist_to_next_branch=total_distance + edge_length, + dist_min_to_target=self.env.distance_map.get()[(agent.handle, *node.position, orientation)] or 0, + num_agents_same_direction=num_agents_same_direction, + num_agents_opposite_direction=num_agents_other_direction, + num_agents_malfunctioning=max_malfunction_length, + speed_min_fractional=min_agent_speed, + num_agents_ready_to_depart=num_agent_departures, + childs=children) + + # Helper functions + + def get_possible_transitions(self, tuple position, int direction): + return [i for i, allowed in enumerate(self.env.rail.get_transitions(*position, direction)) if allowed] + + def get_all_transitions(self, tuple position): + return tuple(tuple(i for i, allowed in enumerate(bits) if allowed == '1') + for bits in f'{self.env.rail.get_full_transitions(*position):019_b}'.split("_")) + + def is_junction(self, tuple position): + return any(map(_check_len1, self.get_all_transitions(position))) + + def is_target(self, tuple position): + return position in self.target_positions -try: - from .tree_observation import ACTIONS, negative_infinity, positive_infinity -except: - from tree_observation import ACTIONS, negative_infinity, positive_infinity + def is_own_target(self, agent, RailNode node): + return agent.target == node.position ZERO_NODE = torch.zeros((1, 11)) @@ -41,4 +402,4 @@ cpdef normalize_observation(tuple observations, int max_depth, shared_tensor, in shared_tensor[starting_index:starting_index + i] = torch.stack([torch.stack([torch.cat(dat, 0) for dat in tree if dat != []], -1) - for tree in data if tree != []], 0) \ No newline at end of file + for tree in data if tree != []], 0) diff --git a/src/train.py b/src/train.py index a9af96d..f1d2008 100644 --- a/src/train.py +++ b/src/train.py @@ -10,17 +10,18 @@ from flatland.envs.rail_env import RailEnv from pathos import multiprocessing +positive_infinity = int(1e5) +negative_infinity = -positive_infinity + try: from .agent import Agent as DQN_Agent, device, BATCH_SIZE from .normalize_output_data import wrap - from .tree_observation import TreeObservation, negative_infinity, positive_infinity - from .observation_utils import normalize_observation + from .observation_utils import normalize_observation, TreeObservation from .railway_utils import load_precomputed_railways, create_random_railways except: from agent import Agent as DQN_Agent, device, BATCH_SIZE from normalize_output_data import wrap - from tree_observation import TreeObservation, negative_infinity, positive_infinity - from observation_utils import normalize_observation + from observation_utils import normalize_observation, TreeObservation from railway_utils import load_precomputed_railways, create_random_railways project_root = Path(__file__).resolve().parent.parent @@ -226,7 +227,7 @@ def normalize(observation, target_tensor): current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken / BATCH_SIZE / agent_count, episode) current_taken, mean_taken = get_means(current_steps, mean_steps, step, episode) - print(f'\rBatch {episode:<5} - Episode {BATCH_SIZE*episode:<5}' + print(f'\rBatch {episode:<5} - Episode {BATCH_SIZE * episode:<5}' f' | Score: {current_score:.4f}, {mean_score:.4f}' f' | Agent-Steps: {current_steps:6.1f}, {mean_steps:6.1f}' f' | Steps Taken: {current_taken:6.1f}, {mean_taken:6.1f}' diff --git a/src/tree_observation.py b/src/tree_observation.py index 1f3e528..e69de29 100644 --- a/src/tree_observation.py +++ b/src/tree_observation.py @@ -1,341 +0,0 @@ -from collections import defaultdict - -import torch -from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4_utils import get_new_position -from flatland.envs.agent_utils import RailAgentStatus -from flatland.envs.observations import TreeObsForRailEnv - -ACTIONS = ['L', 'F', 'R', 'B'] -Node = TreeObsForRailEnv.Node - -positive_infinity = (torch.ones(1) / torch.zeros(1))[0] -negative_infinity = -positive_infinity - - -def first(list): - return next(iter(list)) - - -def get_action(orientation, direction): - return ACTIONS[(direction - orientation + 1 + 4) % 4] - - -def get_direction(orientation, action): - if action == 1: - return (orientation + 4 - 1) % 4 - elif action == 3: - return (orientation + 1) % 4 - else: - return orientation - - -class RailNode: - def __init__(self, position, edge_directions, is_target): - self.edges = {} - self.position = position - self.edge_directions = edge_directions - self.is_target = is_target - - def __repr__(self): - return f'RailNode({self.position}, {len(self.edges)})' - - -class TreeObservation(ObservationBuilder): - def __init__(self, max_depth): - super().__init__() - self.max_depth = max_depth - self.observation_dim = 11 - - # Create a graph representation of the current rail network - - def reset(self): - self.target_positions = {agent.target: 1 for agent in self.env.agents} - self.edge_positions = defaultdict(list) # (cell.position, direction) -> [(start, end, direction, distance)] - self.edge_paths = defaultdict(list) # (node.position, direction) -> [(cell.position, direction)] - - # First, we find a node by starting at one of the agents and following the rails until we reach a junction - agent = first(self.env.agents) - position = agent.initial_position - direction = agent.direction - while True: - try: - out = self.is_junction(position) or self.is_target(position) - except IndexError: - break - if not out: - break - direction = first(self.get_possible_transitions(position, direction)) - position = get_new_position(position, direction) - - # Now we create a graph representation of the rail network, starting from this node - transitions = self.get_all_transitions(position) - root_nodes = {t: RailNode(position, t, self.is_target(position)) for t in transitions if t} - self.graph = {(*position, d): root_nodes[t] for d, t in enumerate(transitions) if t} - - for transitions, node in root_nodes.items(): - for direction in transitions: - self.explore_branch(node, get_new_position(position, direction), direction) - - def explore_branch(self, node, position, direction): - original_direction = direction - edge_positions = {} - distance = 1 - - # Explore until we find a junction - while not self.is_junction(position) and not self.is_target(position): - next_direction = first(self.get_possible_transitions(position, direction)) - edge_positions[(*position, direction)] = (distance, next_direction) - position = get_new_position(position, next_direction) - direction = next_direction - distance += 1 - - # Create any nodes that aren't in the graph yet - transitions = self.get_all_transitions(position) - nodes = {t: RailNode(position, t, self.is_target(position)) - for d, t in enumerate(transitions) - if t and (*position, d) not in self.graph} - - for d, t in enumerate(transitions): - if t in nodes: - self.graph[(*position, d)] = nodes[t] - - # Connect the previous node to the next one, and update self.edge_positions - next_node = self.graph[(*position, direction)] - node.edges[original_direction] = (next_node, distance) - for key, (distance, next_direction) in edge_positions.items(): - self.edge_positions[key].append((node, next_node, original_direction, distance)) - self.edge_paths[node.position, original_direction].append((*key, next_direction)) - - # Call ourselves recursively since we're exploring depth-first - for transitions, node in nodes.items(): - for direction in transitions: - self.explore_branch(node, get_new_position(position, direction), direction) - - # Create a tree observation for each agent, based on the graph we created earlier - - def get_many(self, handles=[]): - self.nodes_with_agents_going, self.edges_with_agents_going = {}, defaultdict(dict) - self.nodes_with_agents_coming, self.edges_with_agents_coming = {}, defaultdict(dict) - self.nodes_with_malfunctions, self.edges_with_malfunctions = {}, defaultdict(dict) - self.nodes_with_departures, self.edges_with_departures = {}, defaultdict(dict) - - # Create some lookup tables that we can use later to figure out how far away the agents are from each other. - for agent in self.env.agents: - if agent.status == RailAgentStatus.READY_TO_DEPART and agent.initial_position: - for direction in range(4): - if (*agent.initial_position, direction) in self.graph: - self.nodes_with_departures[(*agent.initial_position, direction)] = 1 - - for start, _, start_direction, distance in self.edge_positions[ - (*agent.initial_position, direction)]: - self.edges_with_departures[(*start.position, start_direction)][agent.handle] = distance - - if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and agent.position: - agent_key = (*agent.position, agent.direction) - for direction in range(4): - # # Check the nodes - if (*agent.position, direction) in self.graph: - node_dict = self.nodes_with_agents_going if direction == agent.direction else self.nodes_with_agents_coming - node_dict[(*agent.position, direction)] = agent.speed_data['speed'] - - # if len(self.graph[agent_key].edges) > 1: - # exit_direction = get_direction(agent.direction, agent.speed_data['transition_action_on_cellexit']) - # if agent.speed_data['position_fraction'] == 0 or exit_direction not in self.graph[agent_key].edges: # Agent still has options - # self.nodes_with_agents_going[(*agent.position, direction)] = agent.speed_data['speed'] - # else: # Agent has already decided - # coming_direction = (exit_direction + 2) % 4 - # node_dict = self.nodes_with_agents_coming if direction == coming_direction else self.nodes_with_agents_going - # node_dict[(*agent.position, direction)] = agent.speed_data['speed'] - # else: - # exit_direction = first(self.graph[agent_key].edges.keys()) - # coming_direction = (exit_direction + 2) % 4 - # node_dict = self.nodes_with_agents_coming if direction == coming_direction else self.nodes_with_agents_going - # node_dict[(*agent.position, direction)] = agent.speed_data['speed'] - - # Check the edges - if agent_key in self.edge_positions: - exit_direction = first(self.get_possible_transitions(agent.position, agent.direction)) - coming_direction = (exit_direction + 2) % 4 - edge_dict = self.edges_with_agents_coming if direction == coming_direction else self.edges_with_agents_going - if direction == agent.direction or direction == coming_direction: - for start, _, start_direction, distance in self.edge_positions[ - (*agent.position, direction)]: - edge_distance = distance if direction == agent.direction else \ - start.edges[start_direction][1] - distance - edge_dict[(*start.position, start_direction)][agent.handle] = ( - distance, agent.speed_data['speed']) - - # Check for malfunctions - if agent.malfunction_data['malfunction']: - if (*agent.position, direction) in self.graph: - self.nodes_with_malfunctions[(*agent.position, direction)] = agent.malfunction_data[ - 'malfunction'] - - for start, _, start_direction, distance in self.edge_positions[(*agent.position, direction)]: - self.edges_with_malfunctions[(*start.position, start_direction)][agent.handle] = \ - (distance, agent.malfunction_data['malfunction']) - - return super().get_many(handles) - - # Compute the observation for a single agent - def get(self, handle): - agent = self.env.agents[handle] - visited_cells = set() - - if agent.status == RailAgentStatus.READY_TO_DEPART: - agent_position = agent.initial_position - elif agent.status == RailAgentStatus.ACTIVE: - agent_position = agent.position - elif agent.status == RailAgentStatus.DONE: - agent_position = agent.target - else: - return None - - # The root node contains information about the agent itself - children = {x: negative_infinity for x in ACTIONS} - dist_min_to_target = self.env.distance_map.get()[(handle, *agent_position, agent.direction)] - agent_malfunctioning, agent_speed = agent.malfunction_data['malfunction'], agent.speed_data['speed'] - root_tree_node = Node(0, 0, 0, 0, 0, 0, dist_min_to_target, 0, 0, agent_malfunctioning, agent_speed, 0, - children) - - # Then we build out the tree by exploring from this node - key = (*agent_position, agent.direction) - if key in self.graph: # If we're sitting on a junction, branch out immediately - node = self.graph[key] - if len(node.edges) > 1: # Major node - for direction in self.graph[key].edges.keys(): - root_tree_node.childs[get_action(agent.direction, direction)] = \ - self.get_tree_branch(agent, node, direction, visited_cells, 0, 1) - else: # Minor node - direction = first(self.get_possible_transitions(node.position, agent.direction)) - root_tree_node.childs['F'] = self.get_tree_branch(agent, node, direction, visited_cells, 0, 1) - - else: # Just create a single child in the forward direction - prev_node, next_node, direction, distance = first(self.edge_positions[key]) - root_tree_node.childs['F'] = self.get_tree_branch(agent, prev_node, direction, visited_cells, -distance, 1) - - self.env.dev_obs_dict[handle] = visited_cells - - return root_tree_node - - # Get the next tree node, starting from `node`, facing `orientation`, and moving in `direction`. - def get_tree_branch(self, agent, node, direction, visited_cells, total_distance, depth): - visited_cells.add((*node.position, 0)) - next_node, distance = node.edges[direction] - original_position = node.position - - targets, agents, minor_nodes = [], [], [] - edge_length, max_malfunction_length = 0, 0 - num_agents_same_direction, num_agents_other_direction = 0, 0 - distance_to_minor_node, distance_to_other_agent = positive_infinity, positive_infinity - distance_to_own_target, distance_to_other_target = positive_infinity, positive_infinity - min_agent_speed, num_agent_departures = 1.0, 0 - - # Skip ahead until we get to a major node, logging any agents on the tracks along the way - while True: - path = self.edge_paths.get((node.position, direction), []) - orientation = path[-1][-1] if path else direction - dist = total_distance + edge_length - key = (*node.position, direction) - next_key = (*next_node.position, orientation) - - visited_cells.update(path) - visited_cells.add((*next_node.position, 0)) - - # Check for other agents on the junctions up ahead - if next_key in self.nodes_with_agents_going: - num_agents_same_direction += 1 - # distance_to_other_agent = min(distance_to_other_agent, edge_length + distance) - min_agent_speed = min(min_agent_speed, self.nodes_with_agents_going[next_key]) - - if next_key in self.nodes_with_agents_coming: - num_agents_other_direction += 1 - distance_to_other_agent = min(distance_to_other_agent, edge_length + distance) - - if next_key in self.nodes_with_departures: - num_agent_departures += 1 - if next_key in self.nodes_with_malfunctions: - max_malfunction_length = max(max_malfunction_length, self.nodes_with_malfunctions[next_key]) - - # Check for other agents along the tracks up ahead - for d, s in self.edges_with_agents_going[key].values(): - if dist + d > 0: - num_agents_same_direction += 1 - min_agent_speed = min(min_agent_speed, s) - # distance_to_other_agent = min(distance_to_other_agent, edge_length + d) - - for d, _ in self.edges_with_agents_coming[key].values(): - if dist + d > 0: - num_agents_other_direction += 1 - distance_to_other_agent = min(distance_to_other_agent, edge_length + d) - - for d in self.edges_with_departures[key].values(): - if dist + d > 0: - num_agent_departures += 1 - - for d, t in self.edges_with_malfunctions[key].values(): - if dist + d > 0: - max_malfunction_length = max(max_malfunction_length, t) - - # Check for target nodes up ahead - if next_node.is_target: - if self.is_own_target(agent, next_node): - distance_to_own_target = min(distance_to_own_target, edge_length + distance) - else: - distance_to_other_target = min(distance_to_other_target, edge_length + distance) - - # Move on to the next node - node = next_node - edge_length += distance - - if len(node.edges) == 1 and not self.is_own_target(agent, node): # This is a minor node, keep exploring - direction, (next_node, distance) = first(node.edges.items()) - if not node.is_target: - distance_to_minor_node = min(distance_to_minor_node, edge_length) - else: - break - - # Create a new tree node and populate its children - if depth < self.max_depth: - children = {x: negative_infinity for x in ACTIONS} - if not self.is_own_target(agent, node): - for direction in node.edges.keys(): - children[get_action(orientation, direction)] = \ - self.get_tree_branch(agent, node, direction, visited_cells, total_distance + edge_length, - depth + 1) - - else: - children = {} - - return Node(dist_own_target_encountered=total_distance + distance_to_own_target, - dist_other_target_encountered=total_distance + distance_to_other_target, - dist_other_agent_encountered=total_distance + distance_to_other_agent, - dist_potential_conflict=positive_infinity, - dist_unusable_switch=total_distance + distance_to_minor_node, - dist_to_next_branch=total_distance + edge_length, - dist_min_to_target=self.env.distance_map.get()[(agent.handle, *node.position, orientation)] or 0, - num_agents_same_direction=num_agents_same_direction, - num_agents_opposite_direction=num_agents_other_direction, - num_agents_malfunctioning=max_malfunction_length, - speed_min_fractional=min_agent_speed, - num_agents_ready_to_depart=num_agent_departures, - childs=children) - - # Helper functions - - def get_possible_transitions(self, position, direction): - return [i for i, allowed in enumerate(self.env.rail.get_transitions(*position, direction)) if allowed] - - def get_all_transitions(self, position): - bit_groups = f'{self.env.rail.get_full_transitions(*position):019_b}'.split("_") - return [tuple(i for i, allowed in enumerate(bits) if allowed == '1') for bits in bit_groups] - - def is_junction(self, position): - return any(len(transitions) > 1 for transitions in self.get_all_transitions(position)) - - def is_target(self, position): - return position in self.target_positions - - def is_own_target(self, agent, node): - return agent.target == node.position From 949142996d83cbc5bacc14cf4b2279d1d0957370 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 10 Jul 2020 16:19:03 +0200 Subject: [PATCH 18/75] perf: reduce model size, use single qn --- src/agent.py | 53 +++++++++++++++++++++++++++++----------------------- src/model.py | 15 ++++++++++----- src/train.py | 28 +++++++++++++-------------- 3 files changed, 53 insertions(+), 43 deletions(-) diff --git a/src/agent.py b/src/agent.py index 897fdfe..9a70d7a 100644 --- a/src/agent.py +++ b/src/agent.py @@ -17,7 +17,8 @@ GAMMA = 0.998 TAU = 1e-3 LR = 2e-4 -UPDATE_EVERY = 16 +UPDATE_EVERY = 1 +DOUBLE_DQN = False device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -88,25 +89,27 @@ def multi_act(self, state, eps=0.): # Record the results of the agent's action and update the model def step(self, state, action, next_state, agent_done, episode_done, collision, step_reward=-1): - if len(self.stack) >= UPDATE_EVERY - 1: - action = self.memory.stack(self.stack[1]).to(device) - reward = self.memory.stack([1 if ad - else (c - 5 if collision else step_reward) - for ad, c in zip(self.stack[2], self.stack[4])]).to(device) - dones = self.memory.stack(self.stack[3]).to(device).float() - state = state.to(device) - next_state = next_state.to(device) + self.stack[0].append(state) + self.stack[1].append(action) + self.stack[2].append(next_state) + self.stack[3].append([[v or episode_done for k, v in a.items() + if not hasattr(k, 'startswith') + or not k.startswith('_')] for a in agent_done]) + self.stack[4].append(collision) + + if len(self.stack) >= UPDATE_EVERY: + action = torch.tensor(self.stack[1]).flatten(0, 1).to(device) + reward = torch.tensor([[[1 if ad + else (-5 if c else step_reward) for ad, c in zip(ad_batch, c_batch)] + for ad_batch, c_batch in zip(ad_step, c_step)] + for ad_step, c_step in zip(self.stack[3], self.stack[4])]).flatten(0, 1).to(device) + dones = torch.tensor(self.stack[3]).flatten(0, 1).to(device).float() + state = torch.cat(self.stack[0], 0).to(device) + next_state = torch.cat(self.stack[2], 0).to(device) state = torch.cat([state, torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) next_state = torch.cat([next_state, torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) + self.stack = [[] for _ in range(5)] self.learn(state, action, reward, next_state, dones) - else: - self.stack[0].append(state) - self.stack[1].append(action) - self.stack[2].append(next_state) - self.stack[3].append([[v or episode_done for k, v in a.items() - if not hasattr(k, 'startswith') - or not k.startswith('_')] for a in agent_done]) - self.stack[4].append(collision) def learn(self, states, actions, rewards, next_states, dones): self.qnetwork_local.train() @@ -117,15 +120,19 @@ def learn(self, states, actions, rewards, next_states, dones): # Get expected Q values from local model Q_expected = self.qnetwork_local(states.squeeze(1)) - Q_expected = Q_expected.gather(1, actions.unsqueeze(1)) - Q_best_action = self.qnetwork_local(next_states.squeeze(1)).argmax(1) - Q_targets_next = self.qnetwork_target(next_states.squeeze(1)).gather(1, Q_best_action.unsqueeze(1)) + if DOUBLE_DQN: + Q_expected = Q_expected.gather(1, actions.unsqueeze(1)) + Q_best_action = self.qnetwork_local(next_states.squeeze(1)).argmax(1, keepdim=True) + Q_targets_next = self.qnetwork_target(next_states.squeeze(1)).gather(1, Q_best_action) - # Compute loss and perform a gradient step - self.optimizer.zero_grad() - loss = (GAMMA * Q_targets_next * (1 - dones.unsqueeze(-2)) - Q_expected - rewards.unsqueeze(-1)).square().mean() + # Compute loss and perform a gradient step + loss = (GAMMA * Q_targets_next * (1 - dones.unsqueeze(-2)) + - Q_expected - rewards.unsqueeze(-1)).square().mean() + else: + loss = (Q_expected.gather(1, actions.unsqueeze(1)) * rewards.unsqueeze(-1)).clamp(min=0).mean() loss.backward() self.optimizer.step() + self.optimizer.zero_grad() # Update the target network parameters to `tau * local.parameters() + (1 - tau) * target.parameters()` for target_param, local_param in zip(self.qnetwork_target.parameters(), self.qnetwork_local.parameters()): diff --git a/src/model.py b/src/model.py index c8b91bc..b26360f 100644 --- a/src/model.py +++ b/src/model.py @@ -283,8 +283,13 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: def QNetwork(state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=False, debug=True): - model = torch.nn.Sequential(WeightDropConv(state_size + 1, 11 * hidden_factor, bias=False), - torch.nn.BatchNorm1d(11 * hidden_factor), - Mish(), - WeightDropConv(11 * hidden_factor, action_size)) - return model + # model = torch.nn.Sequential(WeightDropConv(state_size + 1, 11 * hidden_factor, bias=False), + # torch.nn.BatchNorm1d(11 * hidden_factor), + # Mish(), + # WeightDropConv(11 * hidden_factor, action_size)) + model = torch.nn.Sequential(torch.nn.Conv1d(state_size + 1, 20, 1, bias=False), + torch.nn.BatchNorm1d(20), + torch.nn.ReLU6(), + torch.nn.Conv1d(20, action_size, 1)) + print(model) + return torch.jit.script(model) diff --git a/src/train.py b/src/train.py index f1d2008..a7971b1 100644 --- a/src/train.py +++ b/src/train.py @@ -84,15 +84,16 @@ rail_generator, schedule_generator = create_random_railways(project_root) # Create the Flatland environment -env = RailEnv(width=flags.grid_width, height=flags.grid_height, number_of_agents=flags.num_agents, - rail_generator=rail_generator, - schedule_generator=schedule_generator, - malfunction_generator_and_process_data=malfunction_from_params(MalfunctionParameters(1 / 8000, 15, 50)), - obs_builder_object=(GlobalObsForRailEnv() - if flags.global_environment - else TreeObservation(max_depth=flags.tree_depth)) - ) -environments = [copy.copy(env) for _ in range(BATCH_SIZE)] +environments = [RailEnv(width=flags.grid_width, height=flags.grid_height, number_of_agents=flags.num_agents, + rail_generator=rail_generator, + schedule_generator=schedule_generator, + malfunction_generator_and_process_data=malfunction_from_params( + MalfunctionParameters(1 / 8000, 15, 50)), + obs_builder_object=(GlobalObsForRailEnv() + if flags.global_environment + else TreeObservation(max_depth=flags.tree_depth)) + ) for _ in range(BATCH_SIZE)] +env = environments[0] # After training we want to render the results so we also load a renderer @@ -151,10 +152,7 @@ def normalize(observation, target_tensor): # Main training loop for episode in range(start + 1, flags.num_episodes + 1): agent.reset() - obs, info = env.reset(True, True) - environments = [copy.copy(env) for _ in range(BATCH_SIZE)] - obs = tuple(copy.deepcopy(obs) for _ in range(BATCH_SIZE)) - info = [copy.deepcopy(info) for _ in range(BATCH_SIZE)] + obs, info = zip(*[env.reset(True, True) for env in environments]) score, steps_taken, collision = 0, 0, False agent_count = len(obs[0]) agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) @@ -168,7 +166,7 @@ def normalize(observation, target_tensor): update_values = [[False] * agent_count for _ in range(BATCH_SIZE)] action_dict = [{} for _ in range(BATCH_SIZE)] - if all(any(inf['action_required']) for inf in info): + if any(any(inf['action_required']) for inf in info): ret_action = agent.multi_act(agent_obs.flatten(1, 2), eps=eps) else: ret_action = update_values @@ -227,7 +225,7 @@ def normalize(observation, target_tensor): current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken / BATCH_SIZE / agent_count, episode) current_taken, mean_taken = get_means(current_steps, mean_steps, step, episode) - print(f'\rBatch {episode:<5} - Episode {BATCH_SIZE * episode:<5}' + print(f'\rBatch {episode:<4} - Episode {BATCH_SIZE * episode:<6}' f' | Score: {current_score:.4f}, {mean_score:.4f}' f' | Agent-Steps: {current_steps:6.1f}, {mean_steps:6.1f}' f' | Steps Taken: {current_taken:6.1f}, {mean_taken:6.1f}' From 0418a674b37517ff5d2981d6b054dd302ce47739 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 10 Jul 2020 16:25:21 +0200 Subject: [PATCH 19/75] perf: cythonize rail_env --- src/cythonize.sh | 19 +- src/rail_env.pyx | 862 +++++++++++++++++++++++++++++++++++++++++++++++ src/train.py | 3 +- 3 files changed, 877 insertions(+), 7 deletions(-) create mode 100644 src/rail_env.pyx diff --git a/src/cythonize.sh b/src/cythonize.sh index e08d63d..2f38ba2 100644 --- a/src/cythonize.sh +++ b/src/cythonize.sh @@ -1,6 +1,13 @@ -cython observation_utils.pyx -3 -Wextra -D -cmd="gcc-7 observation_utils.c `python3-config --cflags --ldflags --includes --libs` -fno-lto -pthread -fPIC -fwrapv -pipe -march=native -mtune=native -Ofast -msse2 -msse4.2 -shared -o observation_utils.so" -echo "Executing $cmd" -$cmd -echo "Testing compilation.." -python3 -c "import observation_utils" +function compile { + file=${1} + cython "$file.pyx" -3 -Wextra -D + cmd="gcc-7 $file.c `python3-config --cflags --ldflags --includes --libs` -fno-lto -pthread -fPIC -fwrapv -pipe -march=native -mtune=native -Ofast -msse2 -msse4.2 -shared -o $file.so" + echo "Executing $cmd" + $cmd + echo "Testing compilation.." + python3 -c "import $file" + echo +} + +compile observation_utils +compile rail_env \ No newline at end of file diff --git a/src/rail_env.pyx b/src/rail_env.pyx new file mode 100644 index 0000000..6dc67bb --- /dev/null +++ b/src/rail_env.pyx @@ -0,0 +1,862 @@ +""" +Definition of the RailEnv environment. +""" +import random +# TODO: _ this is a global method --> utils or remove later +from enum import IntEnum +from typing import List, NamedTuple, Optional, Dict + +import msgpack +import msgpack_numpy as m +import numpy as np +from gym.utils import seeding +from msgpack import Packer + +from flatland.core.env import Environment +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions +from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.grid.grid_utils import IntVector2D +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_utils import EnvAgent, RailAgentStatus +from flatland.envs.distance_map import DistanceMap + +# Need to use circular imports for persistence. +from flatland.envs import malfunction_generators as mal_gen +from flatland.envs import rail_generators as rail_gen +from flatland.envs import schedule_generators as sched_gen +from flatland.envs import persistence + +# Direct import of objects / classes does not work with circular imports. +# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData +# from flatland.envs.observations import GlobalObsForRailEnv +# from flatland.envs.rail_generators import random_rail_generator, RailGenerator +# from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator + +from flatland.envs.observations import GlobalObsForRailEnv + + + +import pickle + +m.patch() + + +class RailEnvActions(IntEnum): + DO_NOTHING = 0 # implies change of direction in a dead-end! + MOVE_LEFT = 1 + MOVE_FORWARD = 2 + MOVE_RIGHT = 3 + STOP_MOVING = 4 + + @staticmethod + def to_char(a: int): + return { + 0: 'B', + 1: 'L', + 2: 'F', + 3: 'R', + 4: 'S', + }[a] + + +RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)]) +RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos), + ('next_direction', Grid4TransitionsEnum)]) + + +class RailEnv(Environment): + """ + RailEnv environment class. + + RailEnv is an environment inspired by a (simplified version of) a rail + network, in which agents (trains) have to navigate to their target + locations in the shortest time possible, while at the same time cooperating + to avoid bottlenecks. + + The valid actions in the environment are: + + - 0: do nothing (continue moving or stay still) + - 1: turn left at switch and move to the next cell; if the agent was not moving, movement is started + - 2: move to the next cell in front of the agent; if the agent was not moving, movement is started + - 3: turn right at switch and move to the next cell; if the agent was not moving, movement is started + - 4: stop moving + + Moving forward in a dead-end cell makes the agent turn 180 degrees and step + to the cell it came from. + + + The actions of the agents are executed in order of their handle to prevent + deadlocks and to allow them to learn relative priorities. + + Reward Function: + + It costs each agent a step_penalty for every time-step taken in the environment. Independent of the movement + of the agent. Currently all other penalties such as penalty for stopping, starting and invalid actions are set to 0. + + alpha = 1 + beta = 1 + Reward function parameters: + + - invalid_action_penalty = 0 + - step_penalty = -alpha + - global_reward = beta + - epsilon = avoid rounding errors + - stop_penalty = 0 # penalty for stopping a moving agent + - start_penalty = 0 # penalty for starting a stopped agent + + Stochastic malfunctioning of trains: + Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid + action or cell is selected. + + Every time an agent stops, an agent has a certain probability of malfunctioning. Malfunctions of trains follow a + poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes to keep + complexity managable. + + TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init(). + For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility. + + """ + alpha = 1.0 + beta = 1.0 + # Epsilon to avoid rounding errors + epsilon = 0.01 + invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty + step_penalty = -1 * alpha + global_reward = 1 * beta + stop_penalty = 0 # penalty for stopping a moving agent + start_penalty = 0 # penalty for starting a stopped agent + + def __init__(self, + width, + height, + rail_generator = None, + schedule_generator = None, # : sched_gen.ScheduleGenerator = sched_gen.random_schedule_generator(), + number_of_agents=1, + obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(), + malfunction_generator_and_process_data=None, #mal_gen.no_malfunction_generator(), + remove_agents_at_target=True, + random_seed=1, + record_steps=False + ): + """ + Environment init. + + Parameters + ---------- + rail_generator : function + The rail_generator function is a function that takes the width, + height and agents handles of a rail environment, along with the number of times + the env has been reset, and returns a GridTransitionMap object and a list of + starting positions, targets, and initial orientations for agent handle. + The rail_generator can pass a distance map in the hints or information for specific schedule_generators. + Implementations can be found in flatland/envs/rail_generators.py + schedule_generator : function + The schedule_generator function is a function that takes the grid, the number of agents and optional hints + and returns a list of starting positions, targets, initial orientations and speed for all agent handles. + Implementations can be found in flatland/envs/schedule_generators.py + width : int + The width of the rail map. Potentially in the future, + a range of widths to sample from. + height : int + The height of the rail map. Potentially in the future, + a range of heights to sample from. + number_of_agents : int + Number of agents to spawn on the map. Potentially in the future, + a range of number of agents to sample from. + obs_builder_object: ObservationBuilder object + ObservationBuilder-derived object that takes builds observation + vectors for each agent. + remove_agents_at_target : bool + If remove_agents_at_target is set to true then the agents will be removed by placing to + RailEnv.DEPOT_POSITION when the agent has reach it's target position. + random_seed : int or None + if None, then its ignored, else the random generators are seeded with this number to ensure + that stochastic operations are replicable across multiple operations + """ + super().__init__() + + if malfunction_generator_and_process_data is None: + malfunction_generator_and_process_data = mal_gen.no_malfunction_generator() + self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data + #self.rail_generator: RailGenerator = rail_generator + if rail_generator is None: + rail_generator = rail_gen.random_rail_generator() + self.rail_generator = rail_generator + #self.schedule_generator: ScheduleGenerator = schedule_generator + if schedule_generator is None: + schedule_generator = sched_gen.random_schedule_generator() + self.schedule_generator = schedule_generator + + self.rail: Optional[GridTransitionMap] = None + self.width = width + self.height = height + + self.remove_agents_at_target = remove_agents_at_target + + self.rewards = [0] * number_of_agents + self.done = False + self.obs_builder = obs_builder_object + self.obs_builder.set_env(self) + + self._max_episode_steps: Optional[int] = None + self._elapsed_steps = 0 + + self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False) + + self.obs_dict = {} + self.rewards_dict = {} + self.dev_obs_dict = {} + self.dev_pred_dict = {} + + self.agents: List[EnvAgent] = [] + self.number_of_agents = number_of_agents + self.num_resets = 0 + self.distance_map = DistanceMap(self.agents, self.height, self.width) + + self.action_space = [5] + + self._seed() + self._seed() + self.random_seed = random_seed + if self.random_seed: + self._seed(seed=random_seed) + + self.valid_positions = None + + # global numpy array of agents position, True means that there is an agent at that cell + self.agent_positions: np.ndarray = np.full((height, width), False) + + # save episode timesteps ie agent positions, orientations. (not yet actions / observations) + self.record_steps = record_steps # whether to save timesteps + # save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps] + self.cur_episode = [] + self.list_actions = [] # save actions in here + + def _seed(self, seed=None): + self.np_random, seed = seeding.np_random(seed) + random.seed(seed) + return [seed] + + # no more agent_handles + def get_agent_handles(self): + return range(self.get_num_agents()) + + def get_num_agents(self) -> int: + return len(self.agents) + + def add_agent(self, agent): + """ Add static info for a single agent. + Returns the index of the new agent. + """ + self.agents.append(agent) + return len(self.agents) - 1 + + def set_agent_active(self, agent: EnvAgent): + if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position): + agent.status = RailAgentStatus.ACTIVE + self._set_agent_to_initial_position(agent, agent.initial_position) + + def reset_agents(self): + """ Reset the agents to their starting positions + """ + for agent in self.agents: + agent.reset() + self.active_agents = [i for i in range(len(self.agents))] + + + + def action_required(self, agent): + """ + Check if an agent needs to provide an action + + Parameters + ---------- + agent: RailEnvAgent + Agent we want to check + + Returns + ------- + True: Agent needs to provide an action + False: Agent cannot provide an action + """ + return (agent.status == RailAgentStatus.READY_TO_DEPART or ( + agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, + rtol=1e-03))) + + def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, + random_seed: bool = None) -> (Dict, Dict): + """ + reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed) + + The method resets the rail environment + + Parameters + ---------- + regenerate_rail : bool, optional + regenerate the rails + regenerate_schedule : bool, optional + regenerate the schedule and the static agents + activate_agents : bool, optional + activate the agents + random_seed : bool, optional + random seed for environment + + Returns + ------- + observation_dict: Dict + Dictionary with an observation for each agent + info_dict: Dict with agent specific information + + """ + + if random_seed: + self._seed(random_seed) + + optionals = {} + if regenerate_rail or self.rail is None: + rail, optionals = self.rail_generator(self.width, self.height, self.number_of_agents, self.num_resets, + self.np_random) + + self.rail = rail + self.height, self.width = self.rail.grid.shape + + # Do a new set_env call on the obs_builder to ensure + # that obs_builder specific instantiations are made according to the + # specifications of the current environment : like width, height, etc + self.obs_builder.set_env(self) + + if optionals and 'distance_map' in optionals: + self.distance_map.set(optionals['distance_map']) + + + + if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0: + agents_hints = None + if optionals and 'agents_hints' in optionals: + agents_hints = optionals['agents_hints'] + + schedule = self.schedule_generator(self.rail, self.number_of_agents, agents_hints, self.num_resets, + self.np_random) + self.agents = EnvAgent.from_schedule(schedule) + + # Get max number of allowed time steps from schedule generator + # Look at the specific schedule generator used to see where this number comes from + self._max_episode_steps = schedule.max_episode_steps + + self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1 + + # Reset agents to initial + self.reset_agents() + + for agent in self.agents: + # Induce malfunctions + if activate_agents: + self.set_agent_active(agent) + + self._break_agent(agent) + + if agent.malfunction_data["malfunction"] > 0: + agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING + + # Fix agents that finished their malfunction + self._fix_agent_after_malfunction(agent) + + self.num_resets += 1 + self._elapsed_steps = 0 + + # TODO perhaps dones should be part of each agent. + self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + + # Reset the state of the observation builder with the new environment + self.obs_builder.reset() + self.distance_map.reset(self.agents, self.rail) + + # Reset the malfunction generator + self.malfunction_generator(reset=True) + + # Empty the episode store of agent positions + self.cur_episode = [] + + info_dict: Dict = { + 'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)}, + 'malfunction': { + i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents) + }, + 'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)}, + 'status': {i: agent.status for i, agent in enumerate(self.agents)} + } + # Return the new observation vectors for each agent + observation_dict: Dict = self._get_observations() + return observation_dict, info_dict + + def _fix_agent_after_malfunction(self, agent: EnvAgent): + """ + Updates agent malfunction variables and fixes broken agents + + Parameters + ---------- + agent + """ + + # Ignore agents that are OK + if self._is_agent_ok(agent): + return + + # Reduce number of malfunction steps left + if agent.malfunction_data['malfunction'] > 1: + agent.malfunction_data['malfunction'] -= 1 + return + + # Restart agents at the end of their malfunction + agent.malfunction_data['malfunction'] -= 1 + if 'moving_before_malfunction' in agent.malfunction_data: + agent.moving = agent.malfunction_data['moving_before_malfunction'] + return + + def _break_agent(self, agent: EnvAgent): + """ + Malfunction generator that breaks agents at a given rate. + + Parameters + ---------- + agent + + """ + + malfunction: Malfunction = self.malfunction_generator(agent, self.np_random) + if malfunction.num_broken_steps > 0: + agent.malfunction_data['malfunction'] = malfunction.num_broken_steps + agent.malfunction_data['moving_before_malfunction'] = agent.moving + agent.malfunction_data['nr_malfunctions'] += 1 + + return + + def step(self, action_dict_: Dict[int, RailEnvActions]): + """ + Updates rewards for the agents at a step. + + Parameters + ---------- + action_dict_ : Dict[int,RailEnvActions] + + """ + self._elapsed_steps += 1 + + # If we're done, set reward and info_dict and step() is done. + if self.dones["__all__"]: + self.rewards_dict = {} + info_dict = { + "action_required": {}, + "malfunction": {}, + "speed": {}, + "status": {}, + } + for i_agent, agent in enumerate(self.agents): + self.rewards_dict[i_agent] = self.global_reward + info_dict["action_required"][i_agent] = False + info_dict["malfunction"][i_agent] = 0 + info_dict["speed"][i_agent] = 0 + info_dict["status"][i_agent] = agent.status + + return self._get_observations(), self.rewards_dict, self.dones, info_dict + + # Reset the step rewards + self.rewards_dict = dict() + info_dict = { + "action_required": {}, + "malfunction": {}, + "speed": {}, + "status": {}, + } + have_all_agents_ended = True # boolean flag to check if all agents are done + + for i_agent, agent in enumerate(self.agents): + # Reset the step rewards + self.rewards_dict[i_agent] = 0 + + # Induce malfunction before we do a step, thus a broken agent can't move in this step + self._break_agent(agent) + + # Perform step on the agent + self._step_agent(i_agent, action_dict_.get(i_agent)) + + # manage the boolean flag to check if all agents are indeed done (or done_removed) + have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]) + + # Build info dict + info_dict["action_required"][i_agent] = self.action_required(agent) + info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction'] + info_dict["speed"][i_agent] = agent.speed_data['speed'] + info_dict["status"][i_agent] = agent.status + + # Fix agents that finished their malfunction such that they can perform an action in the next step + self._fix_agent_after_malfunction(agent) + + # Check for end of episode + set global reward to all rewards! + if have_all_agents_ended: + self.dones["__all__"] = True + self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())} + if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps): + self.dones["__all__"] = True + for i_agent in range(self.get_num_agents()): + self.dones[i_agent] = True + if self.record_steps: + self.record_timestep(action_dict_) + + return self._get_observations(), self.rewards_dict, self.dones, info_dict + + def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None): + """ + Performs a step and step, start and stop penalty on a single agent in the following sub steps: + - malfunction + - action handling if at the beginning of cell + - movement + + Parameters + ---------- + i_agent : int + action_dict_ : Dict[int,RailEnvActions] + + """ + agent = self.agents[i_agent] + if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]: # this agent has already completed... + return + + # agent gets active by a MOVE_* action and if c + if agent.status == RailAgentStatus.READY_TO_DEPART: + if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, + RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position): + agent.status = RailAgentStatus.ACTIVE + self._set_agent_to_initial_position(agent, agent.initial_position) + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + return + else: + # TODO: Here we need to check for the departure time in future releases with full schedules + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + return + + agent.old_direction = agent.direction + agent.old_position = agent.position + + # if agent is broken, actions are ignored and agent does not move. + # full step penalty in this case + if agent.malfunction_data['malfunction'] > 0: + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + return + + # Is the agent at the beginning of the cell? Then, it can take an action. + # As long as the agent is malfunctioning or stopped at the beginning of the cell, + # different actions may be taken! + if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): + # No action has been supplied for this agent -> set DO_NOTHING as default + if action is None: + action = RailEnvActions.DO_NOTHING + + if action < 0 or action > len(RailEnvActions): + print('ERROR: illegal action=', action, + 'for agent with index=', i_agent, + '"DO NOTHING" will be executed instead') + action = RailEnvActions.DO_NOTHING + + if action == RailEnvActions.DO_NOTHING and agent.moving: + # Keep moving + action = RailEnvActions.MOVE_FORWARD + + if action == RailEnvActions.STOP_MOVING and agent.moving: + # Only allow halting an agent on entering new cells. + agent.moving = False + self.rewards_dict[i_agent] += self.stop_penalty + + if not agent.moving and not ( + action == RailEnvActions.DO_NOTHING or + action == RailEnvActions.STOP_MOVING): + # Allow agent to start with any forward or direction action + agent.moving = True + self.rewards_dict[i_agent] += self.start_penalty + + # Store the action if action is moving + # If not moving, the action will be stored when the agent starts moving again. + if agent.moving: + _action_stored = False + _, new_cell_valid, new_direction, new_position, transition_valid = \ + self._check_action_on_agent(action, agent) + + if all([new_cell_valid, transition_valid]): + agent.speed_data['transition_action_on_cellexit'] = action + _action_stored = True + else: + # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, + # try to keep moving forward! + if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT): + _, new_cell_valid, new_direction, new_position, transition_valid = \ + self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) + + if all([new_cell_valid, transition_valid]): + agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD + _action_stored = True + + if not _action_stored: + # If the agent cannot move due to an invalid transition, we set its state to not moving + self.rewards_dict[i_agent] += self.invalid_action_penalty + self.rewards_dict[i_agent] += self.stop_penalty + agent.moving = False + + # Now perform a movement. + # If agent.moving, increment the position_fraction by the speed of the agent + # If the new position fraction is >= 1, reset to 0, and perform the stored + # transition_action_on_cellexit if the cell is free. + if agent.moving: + agent.speed_data['position_fraction'] += agent.speed_data['speed'] + if agent.speed_data['position_fraction'] > 1.0 or np.isclose(agent.speed_data['position_fraction'], 1.0, + rtol=1e-03): + # Perform stored action to transition to the next cell as soon as cell is free + # Notice that we've already checked new_cell_valid and transition valid when we stored the action, + # so we only have to check cell_free now! + + # cell and transition validity was checked when we stored transition_action_on_cellexit! + cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( + agent.speed_data['transition_action_on_cellexit'], agent) + + # N.B. validity of new_cell and transition should have been verified before the action was stored! + assert new_cell_valid + assert transition_valid + if cell_free: + self._move_agent_to_new_position(agent, new_position) + agent.direction = new_direction + agent.speed_data['position_fraction'] = 0.0 + + # has the agent reached its target? + if np.equal(agent.position, agent.target).all(): + agent.status = RailAgentStatus.DONE + self.dones[i_agent] = True + self.active_agents.remove(i_agent) + agent.moving = False + self._remove_agent_from_scene(agent) + else: + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + else: + # step penalty if not moving (stopped now or before) + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + + def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D): + """ + Sets the agent to its initial position. Updates the agent object and the position + of the agent inside the global agent_position numpy array + + Parameters + ------- + agent: EnvAgent object + new_position: IntVector2D + """ + agent.position = new_position + self.agent_positions[agent.position] = agent.handle + + def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D): + """ + Move the agent to the a new position. Updates the agent object and the position + of the agent inside the global agent_position numpy array + + Parameters + ------- + agent: EnvAgent object + new_position: IntVector2D + """ + agent.position = new_position + self.agent_positions[agent.old_position] = -1 + self.agent_positions[agent.position] = agent.handle + + def _remove_agent_from_scene(self, agent: EnvAgent): + """ + Remove the agent from the scene. Updates the agent object and the position + of the agent inside the global agent_position numpy array + + Parameters + ------- + agent: EnvAgent object + """ + self.agent_positions[agent.position] = -1 + if self.remove_agents_at_target: + agent.position = None + agent.status = RailAgentStatus.DONE_REMOVED + + def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent): + """ + + Parameters + ---------- + action : RailEnvActions + agent : EnvAgent + + Returns + ------- + bool + Is it a legal move? + 1) transition allows the new_direction in the cell, + 2) the new cell is not empty (case 0), + 3) the cell is free, i.e., no agent is currently in that cell + + + """ + # compute number of possible transitions in the current + # cell used to check for invalid actions + new_direction, transition_valid = self.check_action(agent, action) + new_position = get_new_position(agent.position, new_direction) + + new_cell_valid = ( + np.array_equal( # Check the new position is still in the grid + new_position, + np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) + and # check the new position has some transitions (ie is not an empty cell) + self.rail.get_full_transitions(*new_position) > 0) + + # If transition validity hasn't been checked yet. + if transition_valid is None: + transition_valid = self.rail.get_transition( + (*agent.position, agent.direction), + new_direction) + + # only call cell_free() if new cell is inside the scene + if new_cell_valid: + # Check the new position is not the same as any of the existing agent positions + # (including itself, for simplicity, since it is moving) + cell_free = self.cell_free(new_position) + else: + # if new cell is outside of scene -> cell_free is False + cell_free = False + return cell_free, new_cell_valid, new_direction, new_position, transition_valid + + def record_timestep(self, dActions): + ''' Record the positions and orientations of all agents in memory, in the cur_episode + ''' + list_agents_state = [] + for i_agent in range(self.get_num_agents()): + agent = self.agents[i_agent] + # the int cast is to avoid numpy types which may cause problems with msgpack + # in env v2, agents may have position None, before starting + if agent.position is None: + pos = (0, 0) + else: + pos = (int(agent.position[0]), int(agent.position[1])) + # print("pos:", pos, type(pos[0])) + list_agents_state.append( + [*pos, int(agent.direction), agent.malfunction_data["malfunction"] ]) + + self.cur_episode.append(list_agents_state) + self.list_actions.append(dActions) + + def cell_free(self, position: IntVector2D) -> bool: + """ + Utility to check if a cell is free + + Parameters: + -------- + position : Tuple[int, int] + + Returns + ------- + bool + is the cell free or not? + + """ + return self.agent_positions[position] == -1 + + def check_action(self, agent: EnvAgent, action: RailEnvActions): + """ + + Parameters + ---------- + agent : EnvAgent + action : RailEnvActions + + Returns + ------- + Tuple[Grid4TransitionsEnum,Tuple[int,int]] + + + + """ + transition_valid = None + possible_transitions = self.rail.get_transitions(*agent.position, agent.direction) + num_transitions = np.count_nonzero(possible_transitions) + + new_direction = agent.direction + if action == RailEnvActions.MOVE_LEFT: + new_direction = agent.direction - 1 + if num_transitions <= 1: + transition_valid = False + + elif action == RailEnvActions.MOVE_RIGHT: + new_direction = agent.direction + 1 + if num_transitions <= 1: + transition_valid = False + + new_direction %= 4 + + if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1: + # - dead-end, straight line or curved line; + # new_direction will be the only valid transition + # - take only available transition + new_direction = np.argmax(possible_transitions) + transition_valid = True + return new_direction, transition_valid + + def _get_observations(self): + """ + Utility which returns the observations for an agent with respect to environment + + Returns + ------ + Dict object + """ + #print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}") + self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) + return self.obs_dict + + def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]: + """ + Returns directions in which the agent can move + + Parameters: + --------- + row : int + col : int + + Returns: + ------- + List[int] + """ + return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col)) + + + + def _exp_distirbution_synced(self, rate: float) -> float: + """ + Generates sample from exponential distribution + We need this to guarantee synchronity between different instances with same seed. + :param rate: + :return: + """ + u = self.np_random.rand() + x = - np.log(1 - u) * rate + return x + + def _is_agent_ok(self, agent: EnvAgent) -> bool: + """ + Check if an agent is ok, meaning it can move and is not malfuncitoinig + Parameters + ---------- + agent + + Returns + ------- + True if agent is ok, False otherwise + + """ + return agent.malfunction_data['malfunction'] < 1 + + def save(self, filename): + print("deprecated call to env.save() - pls call RailEnvPersister.save()") + persistence.RailEnvPersister.save(self, filename) + + diff --git a/src/train.py b/src/train.py index a7971b1..252516e 100644 --- a/src/train.py +++ b/src/train.py @@ -7,18 +7,19 @@ import torch from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters from flatland.envs.observations import GlobalObsForRailEnv -from flatland.envs.rail_env import RailEnv from pathos import multiprocessing positive_infinity = int(1e5) negative_infinity = -positive_infinity try: + from .rail_env import RailEnv from .agent import Agent as DQN_Agent, device, BATCH_SIZE from .normalize_output_data import wrap from .observation_utils import normalize_observation, TreeObservation from .railway_utils import load_precomputed_railways, create_random_railways except: + from rail_env import RailEnv from agent import Agent as DQN_Agent, device, BATCH_SIZE from normalize_output_data import wrap from observation_utils import normalize_observation, TreeObservation From a55ffaf545ee6acf62dee1083af4c348eb894e48 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 10 Jul 2020 16:43:03 +0200 Subject: [PATCH 20/75] style: pycharm reforamt --- src/rail_env.pyx | 52 +++++++++++++++++------------------------------- 1 file changed, 18 insertions(+), 34 deletions(-) diff --git a/src/rail_env.pyx b/src/rail_env.pyx index 6dc67bb..c014c1a 100644 --- a/src/rail_env.pyx +++ b/src/rail_env.pyx @@ -6,26 +6,23 @@ import random from enum import IntEnum from typing import List, NamedTuple, Optional, Dict -import msgpack import msgpack_numpy as m import numpy as np -from gym.utils import seeding -from msgpack import Packer - from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import EnvAgent, RailAgentStatus -from flatland.envs.distance_map import DistanceMap - # Need to use circular imports for persistence. from flatland.envs import malfunction_generators as mal_gen +from flatland.envs import persistence from flatland.envs import rail_generators as rail_gen from flatland.envs import schedule_generators as sched_gen -from flatland.envs import persistence +from flatland.envs.agent_utils import EnvAgent, RailAgentStatus +from flatland.envs.distance_map import DistanceMap +from flatland.envs.observations import GlobalObsForRailEnv +from gym.utils import seeding # Direct import of objects / classes does not work with circular imports. # from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData @@ -33,11 +30,6 @@ from flatland.envs import persistence # from flatland.envs.rail_generators import random_rail_generator, RailGenerator # from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator -from flatland.envs.observations import GlobalObsForRailEnv - - - -import pickle m.patch() @@ -131,10 +123,10 @@ class RailEnv(Environment): width, height, rail_generator = None, - schedule_generator = None, # : sched_gen.ScheduleGenerator = sched_gen.random_schedule_generator(), + schedule_generator = None, # : sched_gen.ScheduleGenerator = sched_gen.random_schedule_generator(), number_of_agents=1, obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(), - malfunction_generator_and_process_data=None, #mal_gen.no_malfunction_generator(), + malfunction_generator_and_process_data=None, #mal_gen.no_malfunction_generator(), remove_agents_at_target=True, random_seed=1, record_steps=False @@ -231,7 +223,7 @@ class RailEnv(Environment): self.record_steps = record_steps # whether to save timesteps # save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps] self.cur_episode = [] - self.list_actions = [] # save actions in here + self.list_actions = [] # save actions in here def _seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) @@ -264,8 +256,6 @@ class RailEnv(Environment): agent.reset() self.active_agents = [i for i in range(len(self.agents))] - - def action_required(self, agent): """ Check if an agent needs to provide an action @@ -281,8 +271,8 @@ class RailEnv(Environment): False: Agent cannot provide an action """ return (agent.status == RailAgentStatus.READY_TO_DEPART or ( - agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, - rtol=1e-03))) + agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, + rtol=1e-03))) def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None) -> (Dict, Dict): @@ -329,8 +319,6 @@ class RailEnv(Environment): if optionals and 'distance_map' in optionals: self.distance_map.set(optionals['distance_map']) - - if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0: agents_hints = None if optionals and 'agents_hints' in optionals: @@ -569,8 +557,8 @@ class RailEnv(Environment): self.rewards_dict[i_agent] += self.stop_penalty if not agent.moving and not ( - action == RailEnvActions.DO_NOTHING or - action == RailEnvActions.STOP_MOVING): + action == RailEnvActions.DO_NOTHING or + action == RailEnvActions.STOP_MOVING): # Allow agent to start with any forward or direction action agent.moving = True self.rewards_dict[i_agent] += self.start_penalty @@ -704,11 +692,11 @@ class RailEnv(Environment): new_position = get_new_position(agent.position, new_direction) new_cell_valid = ( - np.array_equal( # Check the new position is still in the grid - new_position, - np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) - and # check the new position has some transitions (ie is not an empty cell) - self.rail.get_full_transitions(*new_position) > 0) + np.array_equal( # Check the new position is still in the grid + new_position, + np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) + and # check the new position has some transitions (ie is not an empty cell) + self.rail.get_full_transitions(*new_position) > 0) # If transition validity hasn't been checked yet. if transition_valid is None: @@ -740,7 +728,7 @@ class RailEnv(Environment): pos = (int(agent.position[0]), int(agent.position[1])) # print("pos:", pos, type(pos[0])) list_agents_state.append( - [*pos, int(agent.direction), agent.malfunction_data["malfunction"] ]) + [*pos, int(agent.direction), agent.malfunction_data["malfunction"]]) self.cur_episode.append(list_agents_state) self.list_actions.append(dActions) @@ -828,8 +816,6 @@ class RailEnv(Environment): """ return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col)) - - def _exp_distirbution_synced(self, rate: float) -> float: """ Generates sample from exponential distribution @@ -858,5 +844,3 @@ class RailEnv(Environment): def save(self, filename): print("deprecated call to env.save() - pls call RailEnvPersister.save()") persistence.RailEnvPersister.save(self, filename) - - From 4c805d3d4d831e5bf1c18dac850bfd51a7b591aa Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 10 Jul 2020 17:11:35 +0200 Subject: [PATCH 21/75] perf: add typehints to railenv --- src/rail_env.pyx | 175 ++++++++++++++++++++++------------------------- 1 file changed, 83 insertions(+), 92 deletions(-) diff --git a/src/rail_env.pyx b/src/rail_env.pyx index c014c1a..ae73064 100644 --- a/src/rail_env.pyx +++ b/src/rail_env.pyx @@ -34,27 +34,24 @@ from gym.utils import seeding m.patch() -class RailEnvActions(IntEnum): - DO_NOTHING = 0 # implies change of direction in a dead-end! - MOVE_LEFT = 1 - MOVE_FORWARD = 2 - MOVE_RIGHT = 3 - STOP_MOVING = 4 - - @staticmethod - def to_char(a: int): - return { - 0: 'B', - 1: 'L', - 2: 'F', - 3: 'R', - 4: 'S', - }[a] +cdef int DO_NOTHING = 0 # implies change of direction in a dead-end! +cdef int MOVE_LEFT = 1 +cdef int MOVE_FORWARD = 2 +cdef int MOVE_RIGHT = 3 +cdef int STOP_MOVING = 4 +cdef int ACTION_COUNT = 5 + +cpdef str to_char(a: int): + return { + 0: 'B', + 1: 'L', + 2: 'F', + 3: 'R', + 4: 'S', + }[a] RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)]) -RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos), - ('next_direction', Grid4TransitionsEnum)]) class RailEnv(Environment): @@ -120,8 +117,8 @@ class RailEnv(Environment): start_penalty = 0 # penalty for starting a stopped agent def __init__(self, - width, - height, + int width, + int height, rail_generator = None, schedule_generator = None, # : sched_gen.ScheduleGenerator = sched_gen.random_schedule_generator(), number_of_agents=1, @@ -274,8 +271,9 @@ class RailEnv(Environment): agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03))) - def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, - random_seed: bool = None) -> (Dict, Dict): + def reset(self, bint regenerate_rail: bool = True, bint regenerate_schedule: bool = True, + bint activate_agents: bool = False, + bint random_seed: bool = False) -> (Dict, Dict): """ reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed) @@ -303,7 +301,7 @@ class RailEnv(Environment): if random_seed: self._seed(random_seed) - optionals = {} + cdef dict optionals = {} if regenerate_rail or self.rail is None: rail, optionals = self.rail_generator(self.width, self.height, self.number_of_agents, self.num_resets, self.np_random) @@ -345,7 +343,7 @@ class RailEnv(Environment): self._break_agent(agent) if agent.malfunction_data["malfunction"] > 0: - agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING + agent.speed_data['transition_action_on_cellexit'] = DO_NOTHING # Fix agents that finished their malfunction self._fix_agent_after_malfunction(agent) @@ -366,16 +364,16 @@ class RailEnv(Environment): # Empty the episode store of agent positions self.cur_episode = [] - info_dict: Dict = { - 'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)}, - 'malfunction': { - i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents) - }, - 'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)}, - 'status': {i: agent.status for i, agent in enumerate(self.agents)} - } + cdef dict info_dict = { + 'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)}, + 'malfunction': { + i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents) + }, + 'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)}, + 'status': {i: agent.status for i, agent in enumerate(self.agents)} + } # Return the new observation vectors for each agent - observation_dict: Dict = self._get_observations() + cdef dict observation_dict = self._get_observations() return observation_dict, info_dict def _fix_agent_after_malfunction(self, agent: EnvAgent): @@ -420,26 +418,25 @@ class RailEnv(Environment): return - def step(self, action_dict_: Dict[int, RailEnvActions]): + def step(self, dict action_dict_): """ Updates rewards for the agents at a step. - Parameters - ---------- - action_dict_ : Dict[int,RailEnvActions] - """ self._elapsed_steps += 1 # If we're done, set reward and info_dict and step() is done. + cdef dict info_dict = {} + cdef int i_agent = 0 + if self.dones["__all__"]: self.rewards_dict = {} - info_dict = { - "action_required": {}, - "malfunction": {}, - "speed": {}, - "status": {}, - } + info_dict = {"action_required": {}, + "malfunction": {}, + "speed": {}, + "status": {}, + } + for i_agent, agent in enumerate(self.agents): self.rewards_dict[i_agent] = self.global_reward info_dict["action_required"][i_agent] = False @@ -451,13 +448,12 @@ class RailEnv(Environment): # Reset the step rewards self.rewards_dict = dict() - info_dict = { - "action_required": {}, - "malfunction": {}, - "speed": {}, - "status": {}, - } - have_all_agents_ended = True # boolean flag to check if all agents are done + info_dict = {"action_required": {}, + "malfunction": {}, + "speed": {}, + "status": {}, + } + cdef bint have_all_agents_ended = True # boolean flag to check if all agents are done for i_agent, agent in enumerate(self.agents): # Reset the step rewards @@ -494,7 +490,7 @@ class RailEnv(Environment): return self._get_observations(), self.rewards_dict, self.dones, info_dict - def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None): + def _step_agent(self, int i_agent, int action): """ Performs a step and step, start and stop penalty on a single agent in the following sub steps: - malfunction @@ -504,7 +500,6 @@ class RailEnv(Environment): Parameters ---------- i_agent : int - action_dict_ : Dict[int,RailEnvActions] """ agent = self.agents[i_agent] @@ -513,8 +508,7 @@ class RailEnv(Environment): # agent gets active by a MOVE_* action and if c if agent.status == RailAgentStatus.READY_TO_DEPART: - if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, - RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position): + if action in [MOVE_LEFT, MOVE_RIGHT, MOVE_FORWARD] and self.cell_free(agent.initial_position): agent.status = RailAgentStatus.ACTIVE self._set_agent_to_initial_position(agent, agent.initial_position) self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] @@ -539,26 +533,26 @@ class RailEnv(Environment): if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): # No action has been supplied for this agent -> set DO_NOTHING as default if action is None: - action = RailEnvActions.DO_NOTHING + action = DO_NOTHING - if action < 0 or action > len(RailEnvActions): + if action < 0 or action > ACTION_COUNT: # print('ERROR: illegal action=', action, 'for agent with index=', i_agent, '"DO NOTHING" will be executed instead') - action = RailEnvActions.DO_NOTHING + action = DO_NOTHING - if action == RailEnvActions.DO_NOTHING and agent.moving: + if action == DO_NOTHING and agent.moving: # Keep moving - action = RailEnvActions.MOVE_FORWARD + action = MOVE_FORWARD - if action == RailEnvActions.STOP_MOVING and agent.moving: + if action == STOP_MOVING and agent.moving: # Only allow halting an agent on entering new cells. agent.moving = False self.rewards_dict[i_agent] += self.stop_penalty if not agent.moving and not ( - action == RailEnvActions.DO_NOTHING or - action == RailEnvActions.STOP_MOVING): + action == DO_NOTHING or + action == STOP_MOVING): # Allow agent to start with any forward or direction action agent.moving = True self.rewards_dict[i_agent] += self.start_penalty @@ -576,12 +570,12 @@ class RailEnv(Environment): else: # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, # try to keep moving forward! - if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT): + if (action == MOVE_LEFT or action == MOVE_RIGHT): _, new_cell_valid, new_direction, new_position, transition_valid = \ - self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) + self._check_action_on_agent(MOVE_FORWARD, agent) if all([new_cell_valid, transition_valid]): - agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD + agent.speed_data['transition_action_on_cellexit'] = MOVE_FORWARD _action_stored = True if not _action_stored: @@ -668,12 +662,12 @@ class RailEnv(Environment): agent.position = None agent.status = RailAgentStatus.DONE_REMOVED - def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent): + def _check_action_on_agent(self, int action, agent: EnvAgent): """ Parameters ---------- - action : RailEnvActions + action : agent : EnvAgent Returns @@ -688,36 +682,33 @@ class RailEnv(Environment): """ # compute number of possible transitions in the current # cell used to check for invalid actions - new_direction, transition_valid = self.check_action(agent, action) - new_position = get_new_position(agent.position, new_direction) + cdef tuple act_chk = self.check_action(agent, action) + cdef int new_direction = act_chk[0] + cdef bint transition_valid = act_chk[1] + cdef bint cell_free = False + cdef tuple new_position = get_new_position(agent.position, new_direction) - new_cell_valid = ( + cdef bint new_cell_valid = ( np.array_equal( # Check the new position is still in the grid new_position, np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) and # check the new position has some transitions (ie is not an empty cell) self.rail.get_full_transitions(*new_position) > 0) - # If transition validity hasn't been checked yet. - if transition_valid is None: - transition_valid = self.rail.get_transition( - (*agent.position, agent.direction), - new_direction) - # only call cell_free() if new cell is inside the scene if new_cell_valid: # Check the new position is not the same as any of the existing agent positions # (including itself, for simplicity, since it is moving) cell_free = self.cell_free(new_position) - else: - # if new cell is outside of scene -> cell_free is False - cell_free = False + return cell_free, new_cell_valid, new_direction, new_position, transition_valid def record_timestep(self, dActions): ''' Record the positions and orientations of all agents in memory, in the cur_episode ''' - list_agents_state = [] + cdef list list_agents_state = [] + cdef int i_agent = 0 + cdef tuple pos = tuple() for i_agent in range(self.get_num_agents()): agent = self.agents[i_agent] # the int cast is to avoid numpy types which may cause problems with msgpack @@ -749,13 +740,13 @@ class RailEnv(Environment): """ return self.agent_positions[position] == -1 - def check_action(self, agent: EnvAgent, action: RailEnvActions): + def check_action(self, agent: EnvAgent, action): """ Parameters ---------- agent : EnvAgent - action : RailEnvActions + action : Returns ------- @@ -764,24 +755,24 @@ class RailEnv(Environment): """ - transition_valid = None - possible_transitions = self.rail.get_transitions(*agent.position, agent.direction) - num_transitions = np.count_nonzero(possible_transitions) + cdef bint transition_valid = False + cdef tuple possible_transitions = self.rail.get_transitions(*agent.position, agent.direction) + cdef int num_transitions = np.count_nonzero(possible_transitions) + cdef int new_direction = agent.direction - new_direction = agent.direction - if action == RailEnvActions.MOVE_LEFT: + if action == MOVE_LEFT: new_direction = agent.direction - 1 if num_transitions <= 1: transition_valid = False - elif action == RailEnvActions.MOVE_RIGHT: + elif action == MOVE_RIGHT: new_direction = agent.direction + 1 if num_transitions <= 1: transition_valid = False new_direction %= 4 - if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1: + if action == MOVE_FORWARD and num_transitions == 1: # - dead-end, straight line or curved line; # new_direction will be the only valid transition # - take only available transition @@ -823,8 +814,8 @@ class RailEnv(Environment): :param rate: :return: """ - u = self.np_random.rand() - x = - np.log(1 - u) * rate + cdef float u = self.np_random.rand() + cdef float x = - np.log(1 - u) * rate return x def _is_agent_ok(self, agent: EnvAgent) -> bool: From e7b6a9f55b44e6ee3f2e6789993ced30463c655c Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 10 Jul 2020 17:18:34 +0200 Subject: [PATCH 22/75] fix: re-add support for unitialized transition --- src/rail_env.pyx | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/rail_env.pyx b/src/rail_env.pyx index ae73064..942d192 100644 --- a/src/rail_env.pyx +++ b/src/rail_env.pyx @@ -684,7 +684,7 @@ class RailEnv(Environment): # cell used to check for invalid actions cdef tuple act_chk = self.check_action(agent, action) cdef int new_direction = act_chk[0] - cdef bint transition_valid = act_chk[1] + transition_valid = act_chk[1] cdef bint cell_free = False cdef tuple new_position = get_new_position(agent.position, new_direction) @@ -694,7 +694,10 @@ class RailEnv(Environment): np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) and # check the new position has some transitions (ie is not an empty cell) self.rail.get_full_transitions(*new_position) > 0) - + if transition_valid is None: + transition_valid = self.rail.get_transition( + (*agent.position, agent.direction), + new_direction) # only call cell_free() if new cell is inside the scene if new_cell_valid: # Check the new position is not the same as any of the existing agent positions @@ -755,7 +758,7 @@ class RailEnv(Environment): """ - cdef bint transition_valid = False + transition_valid = None cdef tuple possible_transitions = self.rail.get_transitions(*agent.position, agent.direction) cdef int num_transitions = np.count_nonzero(possible_transitions) cdef int new_direction = agent.direction From 0ff41d2ab38602d389e8c6c78c9306fde03ec9c3 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 10 Jul 2020 19:59:47 +0200 Subject: [PATCH 23/75] style: move cat to correct folder --- src/cat.py | 50 ++++++++++++++++++++++++++++++++++++++++++++ src/model.py | 18 +++++++++++++--- src/railway_utils.py | 6 +++--- src/train.py | 10 +++++---- 4 files changed, 74 insertions(+), 10 deletions(-) create mode 100644 src/cat.py diff --git a/src/cat.py b/src/cat.py new file mode 100644 index 0000000..a36b8e2 --- /dev/null +++ b/src/cat.py @@ -0,0 +1,50 @@ +import os +import pickle +import random + +import tqdm + + +def main(bucket0, bucket1): + files = sorted([i for i in os.listdir() if i.endswith('pkl') and 'sum' not in i]) + print(f'Concatenating {", ".join(files)}') + + buckets = [{}, {}] + for fname in tqdm.tqdm(files, ncols=120, leave=False): + with open(fname, 'rb') as f: + try: + dat = pickle.load(f) + except Exception as e: + print(f'Caught {e} while processing {fname}') + dat = list(dat) + try: + _, _ = dat[0] + except: + buckets[1][fname.split('_')[-1].split('.pkl')[0]] = dat[:] + else: + buckets[0][fname.split('_')[-1].split('.pkl')[0]] = dat[:] + + def _get_itm(idx): + items = sorted(list(buckets[idx].items())) + random.seed(0) + random.shuffle(items) + names, items = list(zip(*items)) + print(f"First, Last in sequence: {names[0]}, {names[-1]}") + print(f"Random number _after_ shuffling (to check for seed consitency): {random.random()}") + buckets[idx] = [itm for lst in items for itm in lst] + random.shuffle(buckets[idx]) + + def _dump(idx, dump_name: str): + dump_name += 'sum.pkl' + with open(dump_name, 'wb') as f: + pickle.dump(buckets[idx], f) + print(f'Dumped {len(buckets[idx])} items from {len(files)} sources to {dump_name}') + + _get_itm(0) + _get_itm(1) + _dump(0, bucket0) + _dump(1, bucket1) + + +if __name__ == '__main__': + main('rail_networks_', 'schedules_') diff --git a/src/model.py b/src/model.py index b26360f..a0a64bc 100644 --- a/src/model.py +++ b/src/model.py @@ -36,7 +36,10 @@ def __init__(self, in_features: int, out_features: int, kernel_size=1, bias=True self.bias = torch.nn.Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) - self._kwargs = {'bias': self.bias, 'padding': padding, 'dilation': dilation, 'groups': groups, 'stride': stride} + self.padding = padding + self.dilation = dilation + self.groups = groups + self.stride = stride self._function = function def forward(self, fn_input): @@ -45,7 +48,13 @@ def forward(self, fn_input): else: weight = self.weight - return self._function(fn_input, weight, **self._kwargs) + return self._function(fn_input, + weight, + bias=self.bias, + padding=self.padding, + dilation=self.dilation, + groups=self.groups, + stride=self.stride) def extra_repr(self): return 'in_features={}, out_features={}'.format(self.in_features, self.out_features) @@ -283,13 +292,16 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: def QNetwork(state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=False, debug=True): + # SOMEWHAT OKAYISH # model = torch.nn.Sequential(WeightDropConv(state_size + 1, 11 * hidden_factor, bias=False), # torch.nn.BatchNorm1d(11 * hidden_factor), # Mish(), # WeightDropConv(11 * hidden_factor, action_size)) + + # FAST DEBUG model = torch.nn.Sequential(torch.nn.Conv1d(state_size + 1, 20, 1, bias=False), torch.nn.BatchNorm1d(20), - torch.nn.ReLU6(), + Mish(), torch.nn.Conv1d(20, action_size, 1)) print(model) return torch.jit.script(model) diff --git a/src/railway_utils.py b/src/railway_utils.py index a533fa2..b17c025 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -43,9 +43,9 @@ def __call__(self, *args, **kwargs): # Helper function to load in precomputed railway networks def load_precomputed_railways(project_root, start_index): prefix = os.path.join(project_root, 'railroads') - suffix = f'_3x30x30.pkl' - sched = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) - rail = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) + suffix = f'_sum.pkl' + rail = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) + sched = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) print(f"Working on {len(rail)} tracks") return rail, sched diff --git a/src/train.py b/src/train.py index 252516e..e0f807e 100644 --- a/src/train.py +++ b/src/train.py @@ -1,5 +1,4 @@ import argparse -import copy import time from itertools import zip_longest from pathlib import Path @@ -92,8 +91,9 @@ MalfunctionParameters(1 / 8000, 15, 50)), obs_builder_object=(GlobalObsForRailEnv() if flags.global_environment - else TreeObservation(max_depth=flags.tree_depth)) - ) for _ in range(BATCH_SIZE)] + else TreeObservation(max_depth=flags.tree_depth)), + random_seed=i) + for i in range(BATCH_SIZE)] env = environments[0] # After training we want to render the results so we also load a renderer @@ -154,8 +154,10 @@ def normalize(observation, target_tensor): for episode in range(start + 1, flags.num_episodes + 1): agent.reset() obs, info = zip(*[env.reset(True, True) for env in environments]) + score, steps_taken, collision = 0, 0, False agent_count = len(obs[0]) + print(agent_count) agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) normalize(obs, agent_obs) agent_obs_buffer = agent_obs.clone() @@ -226,7 +228,7 @@ def normalize(observation, target_tensor): current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken / BATCH_SIZE / agent_count, episode) current_taken, mean_taken = get_means(current_steps, mean_steps, step, episode) - print(f'\rBatch {episode:<4} - Episode {BATCH_SIZE * episode:<6}' + print(f'\rBatch {episode:>4} - Episode {BATCH_SIZE * episode:>6}' f' | Score: {current_score:.4f}, {mean_score:.4f}' f' | Agent-Steps: {current_steps:6.1f}, {mean_steps:6.1f}' f' | Steps Taken: {current_taken:6.1f}, {mean_taken:6.1f}' From 4b3b66674ce5801bfbaf77a11213024010aeb53f Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 11 Jul 2020 18:37:04 +0200 Subject: [PATCH 24/75] perf: use ppo --- src/agent.py | 70 +++++++++++++----------------- src/generate_railways.py | 6 ++- src/rail_env.pyx | 94 +++++++++++++--------------------------- src/railway_utils.py | 15 ++++--- src/train.py | 5 +-- 5 files changed, 76 insertions(+), 114 deletions(-) diff --git a/src/agent.py b/src/agent.py index 9a70d7a..85b176d 100644 --- a/src/agent.py +++ b/src/agent.py @@ -13,9 +13,10 @@ import os BUFFER_SIZE = 500_000 -BATCH_SIZE = 64 +BATCH_SIZE = 256 GAMMA = 0.998 TAU = 1e-3 +CLIP_FACTOR = 0.2 LR = 2e-4 UPDATE_EVERY = 1 DOUBLE_DQN = False @@ -33,20 +34,21 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s network = ConvNetwork else: network = QNetwork - self.qnetwork_local = network(state_size, - action_size, - hidden_factor, - model_depth, - kernel_size, - squeeze_heads).to(device) - self.qnetwork_target = network(state_size, - action_size, - hidden_factor, - model_depth, - kernel_size, - squeeze_heads, - debug=False).to(device) - self.optimizer = Optimizer(self.qnetwork_local.parameters(), lr=LR, weight_decay=1e-2) + self.policy = network(state_size, + action_size, + hidden_factor, + model_depth, + kernel_size, + squeeze_heads).to(device) + self.old_policy = network(state_size, + action_size, + hidden_factor, + model_depth, + kernel_size, + squeeze_heads, + debug=False).to(device) + self.old_policy.load_state_dict(self.policy.state_dict()) + self.optimizer = Optimizer(self.policy.parameters(), lr=LR, weight_decay=1e-2) # Replay memory self.memory = ReplayBuffer(BATCH_SIZE) @@ -62,9 +64,9 @@ def act(self, state, eps=0.): agent_count = len(state) state = torch.stack(state, -1).unsqueeze(0).to(device) state = torch.cat([state, torch.randn(1, 1, state.size(-1), device=device)], 1) - self.qnetwork_local.eval() + self.policy.eval() with torch.no_grad(): - action_values = self.qnetwork_local(state) + action_values = self.policy(state) # Epsilon-greedy action selection return [torch.argmax(action_values[:, :, i], 1).item() @@ -75,9 +77,9 @@ def act(self, state, eps=0.): def multi_act(self, state, eps=0.): agent_count = state.size(-1) state = torch.cat([state.to(device), torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) - self.qnetwork_local.eval() + self.policy.eval() with torch.no_grad(): - action_values = self.qnetwork_local(state) + action_values = self.policy(state) # Epsilon-greedy action selection return [[torch.argmax(act[:, i], 0).item() @@ -112,31 +114,21 @@ def step(self, state, action, next_state, agent_done, episode_done, collision, s self.learn(state, action, reward, next_state, dones) def learn(self, states, actions, rewards, next_states, dones): - self.qnetwork_local.train() + self.policy.train() - actions.squeeze_(-1) - dones.squeeze_(-1) + responsible_outputs = torch.gather(self.policy(states), 1, actions) + old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach() - # Get expected Q values from local model - Q_expected = self.qnetwork_local(states.squeeze(1)) + # rewards = rewards - rewards.mean() + ratio = responsible_outputs / (old_responsible_outputs + 1e-5) + clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) + loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean() - if DOUBLE_DQN: - Q_expected = Q_expected.gather(1, actions.unsqueeze(1)) - Q_best_action = self.qnetwork_local(next_states.squeeze(1)).argmax(1, keepdim=True) - Q_targets_next = self.qnetwork_target(next_states.squeeze(1)).gather(1, Q_best_action) - - # Compute loss and perform a gradient step - loss = (GAMMA * Q_targets_next * (1 - dones.unsqueeze(-2)) - - Q_expected - rewards.unsqueeze(-1)).square().mean() - else: - loss = (Q_expected.gather(1, actions.unsqueeze(1)) * rewards.unsqueeze(-1)).clamp(min=0).mean() + # Compute loss and perform a gradient step + self.old_policy.load_state_dict(self.policy.state_dict()) + self.optimizer.zero_grad() loss.backward() self.optimizer.step() - self.optimizer.zero_grad() - - # Update the target network parameters to `tau * local.parameters() + (1 - tau) * target.parameters()` - for target_param, local_param in zip(self.qnetwork_target.parameters(), self.qnetwork_local.parameters()): - target_param.data.copy_(TAU * local_param.data + (1.0 - TAU) * target_param.data) # Checkpointing methods diff --git a/src/generate_railways.py b/src/generate_railways.py index 2b05355..9df58fe 100755 --- a/src/generate_railways.py +++ b/src/generate_railways.py @@ -17,10 +17,12 @@ parser.add_argument("--agents", type=int, default=3, help="Number of episodes to train for") parser.add_argument("--cities", type=int, default=3, help="Number of episodes to train for") parser.add_argument("--width", type=int, default=35, help="Decay factor for epsilon-greedy exploration") +parser.add_argument("--height", type=int, default=None, help="Decay factor for epsilon-greedy exploration") flags = parser.parse_args() -width = height = flags.width +width = flags.width +height = width if flags.height is None else flags.height n_agents = flags.agents rail_generator, schedule_generator = create_random_railways(project_root, flags.cities) @@ -48,7 +50,7 @@ def do(schedules: list, rail_networks: list): shared_schedules = manager.list(schedules) shared_rail_networks = manager.list(rail_networks) # Generate 10000 random railways in 100 batches of 100 -for _ in tqdm(range(100), ncols=120, leave=False): +for _ in tqdm(range(500), ncols=150, leave=False): do(schedules, rail_networks) with open(project_root / f'railroads/rail_networks_{n_agents}x{width}x{height}.pkl', 'wb') as file: pickle.dump(schedules, file, protocol=4) diff --git a/src/rail_env.pyx b/src/rail_env.pyx index 942d192..842e884 100644 --- a/src/rail_env.pyx +++ b/src/rail_env.pyx @@ -3,14 +3,13 @@ Definition of the RailEnv environment. """ import random # TODO: _ this is a global method --> utils or remove later -from enum import IntEnum from typing import List, NamedTuple, Optional, Dict import msgpack_numpy as m import numpy as np from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions +from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap @@ -33,7 +32,6 @@ from gym.utils import seeding m.patch() - cdef int DO_NOTHING = 0 # implies change of direction in a dead-end! cdef int MOVE_LEFT = 1 cdef int MOVE_FORWARD = 2 @@ -50,7 +48,6 @@ cpdef str to_char(a: int): 4: 'S', }[a] - RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)]) @@ -205,11 +202,9 @@ class RailEnv(Environment): self.action_space = [5] - self._seed() - self._seed() self.random_seed = random_seed - if self.random_seed: - self._seed(seed=random_seed) + self.np_random, seed = seeding.np_random(random_seed) + random.seed(seed) self.valid_positions = None @@ -222,30 +217,9 @@ class RailEnv(Environment): self.cur_episode = [] self.list_actions = [] # save actions in here - def _seed(self, seed=None): - self.np_random, seed = seeding.np_random(seed) - random.seed(seed) - return [seed] - - # no more agent_handles - def get_agent_handles(self): - return range(self.get_num_agents()) - def get_num_agents(self) -> int: return len(self.agents) - def add_agent(self, agent): - """ Add static info for a single agent. - Returns the index of the new agent. - """ - self.agents.append(agent) - return len(self.agents) - 1 - - def set_agent_active(self, agent: EnvAgent): - if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position): - agent.status = RailAgentStatus.ACTIVE - self._set_agent_to_initial_position(agent, agent.initial_position) - def reset_agents(self): """ Reset the agents to their starting positions """ @@ -271,9 +245,7 @@ class RailEnv(Environment): agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03))) - def reset(self, bint regenerate_rail: bool = True, bint regenerate_schedule: bool = True, - bint activate_agents: bool = False, - bint random_seed: bool = False) -> (Dict, Dict): + def reset(self) -> (Dict, Dict): """ reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed) @@ -298,37 +270,32 @@ class RailEnv(Environment): """ - if random_seed: - self._seed(random_seed) - cdef dict optionals = {} - if regenerate_rail or self.rail is None: - rail, optionals = self.rail_generator(self.width, self.height, self.number_of_agents, self.num_resets, - self.np_random) + rail, optionals = self.rail_generator(self.width, self.height, self.number_of_agents, self.num_resets, + self.np_random) - self.rail = rail - self.height, self.width = self.rail.grid.shape + self.rail = rail + self.height, self.width = self.rail.grid.shape - # Do a new set_env call on the obs_builder to ensure - # that obs_builder specific instantiations are made according to the - # specifications of the current environment : like width, height, etc - self.obs_builder.set_env(self) + # Do a new set_env call on the obs_builder to ensure + # that obs_builder specific instantiations are made according to the + # specifications of the current environment : like width, height, etc + self.obs_builder.set_env(self) if optionals and 'distance_map' in optionals: self.distance_map.set(optionals['distance_map']) - if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0: - agents_hints = None - if optionals and 'agents_hints' in optionals: - agents_hints = optionals['agents_hints'] + agents_hints = None + if optionals and 'agents_hints' in optionals: + agents_hints = optionals['agents_hints'] - schedule = self.schedule_generator(self.rail, self.number_of_agents, agents_hints, self.num_resets, - self.np_random) - self.agents = EnvAgent.from_schedule(schedule) + schedule = self.schedule_generator(self.rail, self.number_of_agents, agents_hints, self.num_resets, + self.np_random) + self.agents = EnvAgent.from_schedule(schedule) - # Get max number of allowed time steps from schedule generator - # Look at the specific schedule generator used to see where this number comes from - self._max_episode_steps = schedule.max_episode_steps + # Get max number of allowed time steps from schedule generator + # Look at the specific schedule generator used to see where this number comes from + self._max_episode_steps = schedule.max_episode_steps self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1 @@ -336,9 +303,6 @@ class RailEnv(Environment): self.reset_agents() for agent in self.agents: - # Induce malfunctions - if activate_agents: - self.set_agent_active(agent) self._break_agent(agent) @@ -365,13 +329,13 @@ class RailEnv(Environment): self.cur_episode = [] cdef dict info_dict = { - 'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)}, - 'malfunction': { - i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents) - }, - 'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)}, - 'status': {i: agent.status for i, agent in enumerate(self.agents)} - } + 'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)}, + 'malfunction': { + i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents) + }, + 'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)}, + 'status': {i: agent.status for i, agent in enumerate(self.agents)} + } # Return the new observation vectors for each agent cdef dict observation_dict = self._get_observations() return observation_dict, info_dict @@ -535,7 +499,7 @@ class RailEnv(Environment): if action is None: action = DO_NOTHING - if action < 0 or action > ACTION_COUNT: # + if action < 0 or action > ACTION_COUNT: # print('ERROR: illegal action=', action, 'for agent with index=', i_agent, '"DO NOTHING" will be executed instead') diff --git a/src/railway_utils.py b/src/railway_utils.py index b17c025..2584919 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -41,11 +41,16 @@ def __call__(self, *args, **kwargs): # Helper function to load in precomputed railway networks -def load_precomputed_railways(project_root, start_index): +def load_precomputed_railways(project_root, start_index, big=False): prefix = os.path.join(project_root, 'railroads') - suffix = f'_sum.pkl' - rail = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) - sched = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) + if big: + suffix = f'_50x35x20.pkl' + else: + suffix = f'_3x30x30.pkl' + sched = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) + rail = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) + if big: + sched, rail = rail, sched print(f"Working on {len(rail)} tracks") return rail, sched @@ -59,6 +64,6 @@ def create_random_railways(project_root, max_cities=5): 1 / 4.: 0.0} # Slow freight train rail_generator = sparse_rail_generator(grid_mode=False, max_num_cities=max_cities, - max_rails_between_cities=max_cities - 1, max_rails_in_city=max_cities - 1) + max_rails_between_cities=2, max_rails_in_city=3) schedule_generator = sparse_schedule_generator(speed_ratio_map) return rail_generator, schedule_generator diff --git a/src/train.py b/src/train.py index e0f807e..fe94b7a 100644 --- a/src/train.py +++ b/src/train.py @@ -88,7 +88,7 @@ rail_generator=rail_generator, schedule_generator=schedule_generator, malfunction_generator_and_process_data=malfunction_from_params( - MalfunctionParameters(1 / 8000, 15, 50)), + MalfunctionParameters(1 / 500, 20, 50)), obs_builder_object=(GlobalObsForRailEnv() if flags.global_environment else TreeObservation(max_depth=flags.tree_depth)), @@ -153,11 +153,10 @@ def normalize(observation, target_tensor): # Main training loop for episode in range(start + 1, flags.num_episodes + 1): agent.reset() - obs, info = zip(*[env.reset(True, True) for env in environments]) + obs, info = zip(*[env.reset() for env in environments]) score, steps_taken, collision = 0, 0, False agent_count = len(obs[0]) - print(agent_count) agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) normalize(obs, agent_obs) agent_obs_buffer = agent_obs.clone() From 18e5597f3c40ada517d0744e4c52f676b32079ac Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 11 Jul 2020 21:22:39 +0200 Subject: [PATCH 25/75] feat: give ppo network previous state --- src/agent.py | 47 +++++++++++++++++++++----------------------- src/model.py | 26 +++++++++++++----------- src/railway_utils.py | 2 +- src/train.py | 8 +++----- 4 files changed, 40 insertions(+), 43 deletions(-) diff --git a/src/agent.py b/src/agent.py index 85b176d..f060c76 100644 --- a/src/agent.py +++ b/src/agent.py @@ -24,6 +24,8 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + class Agent: def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_size, squeeze_heads, use_global=False): @@ -52,7 +54,7 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s # Replay memory self.memory = ReplayBuffer(BATCH_SIZE) - self.stack = [[] for _ in range(5)] + self.stack = [[] for _ in range(4)] self.t_step = 0 def reset(self): @@ -63,7 +65,6 @@ def reset(self): def act(self, state, eps=0.): agent_count = len(state) state = torch.stack(state, -1).unsqueeze(0).to(device) - state = torch.cat([state, torch.randn(1, 1, state.size(-1), device=device)], 1) self.policy.eval() with torch.no_grad(): action_values = self.policy(state) @@ -76,7 +77,7 @@ def act(self, state, eps=0.): def multi_act(self, state, eps=0.): agent_count = state.size(-1) - state = torch.cat([state.to(device), torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) + state = state.to(device) self.policy.eval() with torch.no_grad(): action_values = self.policy(state) @@ -90,40 +91,36 @@ def multi_act(self, state, eps=0.): # Record the results of the agent's action and update the model - def step(self, state, action, next_state, agent_done, episode_done, collision, step_reward=-1): + def step(self, state, action, agent_done, collision, step_reward=0, collision_reward=-2): self.stack[0].append(state) self.stack[1].append(action) - self.stack[2].append(next_state) - self.stack[3].append([[v or episode_done for k, v in a.items() + self.stack[2].append([[v for k, v in a.items() if not hasattr(k, 'startswith') or not k.startswith('_')] for a in agent_done]) - self.stack[4].append(collision) + self.stack[3].append(collision) if len(self.stack) >= UPDATE_EVERY: action = torch.tensor(self.stack[1]).flatten(0, 1).to(device) reward = torch.tensor([[[1 if ad - else (-5 if c else step_reward) for ad, c in zip(ad_batch, c_batch)] + else (collision_reward if c else step_reward) for ad, c in zip(ad_batch, c_batch)] for ad_batch, c_batch in zip(ad_step, c_step)] - for ad_step, c_step in zip(self.stack[3], self.stack[4])]).flatten(0, 1).to(device) - dones = torch.tensor(self.stack[3]).flatten(0, 1).to(device).float() + for ad_step, c_step in zip(self.stack[2], self.stack[3])]).flatten(0, 1).to(device) state = torch.cat(self.stack[0], 0).to(device) - next_state = torch.cat(self.stack[2], 0).to(device) - state = torch.cat([state, torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) - next_state = torch.cat([next_state, torch.randn(state.size(0), 1, state.size(-1), device=device)], 1) - self.stack = [[] for _ in range(5)] - self.learn(state, action, reward, next_state, dones) + self.stack = [[] for _ in range(4)] + self.learn(state, action, reward) - def learn(self, states, actions, rewards, next_states, dones): + def learn(self, states, actions, rewards): self.policy.train() - - responsible_outputs = torch.gather(self.policy(states), 1, actions) - old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach() - - # rewards = rewards - rewards.mean() + actions.unsqueeze_(1) + responsible_outputs = self.policy(states).gather(1, actions) + old_responsible_outputs = self.old_policy(states).gather(1, actions) + old_responsible_outputs.detach_() ratio = responsible_outputs / (old_responsible_outputs + 1e-5) clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean() + # rewards = rewards - rewards.mean() + # Compute loss and perform a gradient step self.old_policy.load_state_dict(self.policy.state_dict()) self.optimizer.zero_grad() @@ -133,8 +130,8 @@ def learn(self, states, actions, rewards, next_states, dones): # Checkpointing methods def save(self, path, *data): - torch.save(self.qnetwork_local.state_dict(), path / 'dqn/model_checkpoint.local') - torch.save(self.qnetwork_target.state_dict(), path / 'dqn/model_checkpoint.target') + torch.save(self.policy.state_dict(), path / 'dqn/model_checkpoint.local') + torch.save(self.old_policy.state_dict(), path / 'dqn/model_checkpoint.target') torch.save(self.optimizer.state_dict(), path / 'dqn/model_checkpoint.optimizer') with open(path / 'dqn/model_checkpoint.meta', 'wb') as file: pickle.dump(data, file) @@ -144,8 +141,8 @@ def load(self, path, *defaults): try: print("Loading model from checkpoint...") dqn = os.path.join(path, 'dqn') - self.qnetwork_local.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.local'), **loc)) - self.qnetwork_target.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.target'), **loc)) + self.policy.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.local'), **loc)) + self.old_policy.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.target'), **loc)) self.optimizer.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.optimizer'), **loc)) with open(os.path.join(dqn, 'model_checkpoint.meta'), 'rb') as file: return pickle.load(file) diff --git a/src/model.py b/src/model.py index a0a64bc..0ce4699 100644 --- a/src/model.py +++ b/src/model.py @@ -1,5 +1,6 @@ import typing - +import numpy as np +import math import torch @@ -292,16 +293,17 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: def QNetwork(state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=False, debug=True): - # SOMEWHAT OKAYISH - # model = torch.nn.Sequential(WeightDropConv(state_size + 1, 11 * hidden_factor, bias=False), - # torch.nn.BatchNorm1d(11 * hidden_factor), - # Mish(), - # WeightDropConv(11 * hidden_factor, action_size)) - - # FAST DEBUG - model = torch.nn.Sequential(torch.nn.Conv1d(state_size + 1, 20, 1, bias=False), - torch.nn.BatchNorm1d(20), + model = torch.nn.Sequential(torch.nn.Conv1d(2*state_size, 33, 1, groups=11, bias=False), + torch.nn.BatchNorm1d(33), + Mish(), + WeightDropConv(33, 33), + torch.nn.BatchNorm1d(33), Mish(), - torch.nn.Conv1d(20, action_size, 1)) - print(model) + WeightDropConv(33, action_size, 1)) + if debug: + parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, model.parameters())) + digits = int(math.log10(parameters)) + number_string = " kMGTPEZY"[digits // 3] + + print(f"[DEBUG/MODEL] Training with {parameters * 10 ** -(digits // 3 * 3):.1f}{number_string} parameters") return torch.jit.script(model) diff --git a/src/railway_utils.py b/src/railway_utils.py index 2584919..a81d511 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -41,7 +41,7 @@ def __call__(self, *args, **kwargs): # Helper function to load in precomputed railway networks -def load_precomputed_railways(project_root, start_index, big=False): +def load_precomputed_railways(project_root, start_index, big=True): prefix = os.path.join(project_root, 'railroads') if big: suffix = f'_50x35x20.pkl' diff --git a/src/train.py b/src/train.py index fe94b7a..e0cb12c 100644 --- a/src/train.py +++ b/src/train.py @@ -167,9 +167,9 @@ def normalize(observation, target_tensor): for step in range(max_steps): update_values = [[False] * agent_count for _ in range(BATCH_SIZE)] action_dict = [{} for _ in range(BATCH_SIZE)] - + input_tensor = torch.cat([agent_obs_buffer.flatten(1, 2), agent_obs.flatten(1, 2)], 1) if any(any(inf['action_required']) for inf in info): - ret_action = agent.multi_act(agent_obs.flatten(1, 2), eps=eps) + ret_action = agent.multi_act(input_tensor, eps=eps) else: ret_action = update_values for idx, act_list in enumerate(ret_action): @@ -194,11 +194,9 @@ def normalize(observation, target_tensor): # Update replay buffer and train agent if flags.train and (any(update_values) or all_done or all(any(d) for d in done)): - agent.step(agent_obs_buffer.flatten(1, 2), + agent.step(input_tensor, agent_action_buffer, - agent_obs.flatten(1, 2), done, - all_done, [[is_collision(a, i) for a in range(agent_count)] for i in range(BATCH_SIZE)], flags.step_reward) agent_obs_buffer = agent_obs.clone() From af8952af6fedd49ea90b790fe1348f0987a77ce2 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 13 Jul 2020 09:04:04 +0200 Subject: [PATCH 26/75] feat: railway_utils.py increase agent count over time --- src/railway_utils.py | 56 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/src/railway_utils.py b/src/railway_utils.py index a81d511..9a99b95 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -3,6 +3,11 @@ from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator +import numpy as np +try: + from .agent import BATCH_SIZE +except: + from agent import BATCH_SIZE class Generator: @@ -40,6 +45,44 @@ def __call__(self, *args, **kwargs): return next(self) +class RailGenerator: + def __init__(self, width=35, base=1.5): + self.rail_generator = sparse_rail_generator(grid_mode=False, max_num_cities=3, max_rails_between_cities=2, + max_rails_in_city=3) + self.sub_idx = 0 + self.top_idx = 0 + self.width = width + self.base = base + + def __next__(self): + self.sub_idx += 1 + if self.sub_idx == BATCH_SIZE: + self.sub_idx = 0 + self.top_idx += 1 + return self.rail_generator(self.width, self.width, int(2 * self.base ** self.top_idx), np_random=np.random) + + def __call__(self, *args, **kwargs): + return next(self) + + +class ScheduleGenerator: + def __init__(self, base=1.5): + self.schedule_generator = sparse_schedule_generator({1.: 1.}) + self.sub_idx = 0 + self.top_idx = 0 + self.base = base + + def __next__(self, rail, hints): + if self.sub_idx == BATCH_SIZE: + self.sub_idx = 0 + self.top_idx += 1 + self.sub_idx += 1 + return self.schedule_generator(rail, int(2 * self.base ** self.top_idx), hints, np_random=np.random) + + def __call__(self, rail, _, hints, *args, **kwargs): + return self.__next__(rail, hints) + + # Helper function to load in precomputed railway networks def load_precomputed_railways(project_root, start_index, big=True): prefix = os.path.join(project_root, 'railroads') @@ -56,14 +99,5 @@ def load_precomputed_railways(project_root, start_index, big=True): # Helper function to generate railways on the fly -def create_random_railways(project_root, max_cities=5): - speed_ratio_map = { - 1 / 1: 1.0, # Fast passenger train - 1 / 2.: 0.0, # Fast freight train - 1 / 3.: 0.0, # Slow commuter train - 1 / 4.: 0.0} # Slow freight train - - rail_generator = sparse_rail_generator(grid_mode=False, max_num_cities=max_cities, - max_rails_between_cities=2, max_rails_in_city=3) - schedule_generator = sparse_schedule_generator(speed_ratio_map) - return rail_generator, schedule_generator +def create_random_railways(base=1.1): + return RailGenerator(base=base), ScheduleGenerator(base=base) From a904cf4b75ba239e9ea8a9e06c5b5e990c4ff797 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 13 Jul 2020 09:07:14 +0200 Subject: [PATCH 27/75] feat(train): add comfort functions to train.py --- src/train.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/train.py b/src/train.py index e0cb12c..2338d56 100644 --- a/src/train.py +++ b/src/train.py @@ -8,6 +8,8 @@ from flatland.envs.observations import GlobalObsForRailEnv from pathos import multiprocessing +torch.jit.optimized_execution(True) + positive_infinity = int(1e5) negative_infinity = -positive_infinity @@ -32,7 +34,7 @@ parser.add_argument("--train", type=boolean, default=True, help="Whether to train the model or just evaluate it") parser.add_argument("--load-model", default=False, action='store_true', help="Whether to load the model from the last checkpoint") -parser.add_argument("--load-railways", type=boolean, default=True, +parser.add_argument("--load-railways", type=boolean, default=False, help="Whether to load in pre-generated railway networks") parser.add_argument("--report-interval", type=int, default=100, help="Iterations between reports") parser.add_argument("--render-interval", type=int, default=0, help="Iterations between renders") @@ -51,7 +53,8 @@ parser.add_argument("--agent-type", default="dqn", choices=["dqn", "ppo"], help="Which type of RL agent to use") parser.add_argument("--num-episodes", type=int, default=10 ** 6, help="Number of episodes to train for") parser.add_argument("--epsilon-decay", type=float, default=0, help="Decay factor for epsilon-greedy exploration") -parser.add_argument("--step-reward", type=float, default=-1, help="Depth of the observation tree") +parser.add_argument("--step-reward", type=float, default=-1e-2, help="Depth of the observation tree") +parser.add_argument("--collision-reward", type=float, default=-2, help="Depth of the observation tree") parser.add_argument("--global-environment", type=boolean, default=False, help="Depth of the observation tree") parser.add_argument("--threads", type=int, default=1, help="Depth of the observation tree") @@ -81,7 +84,7 @@ if flags.load_railways: rail_generator, schedule_generator = load_precomputed_railways(project_root, start) else: - rail_generator, schedule_generator = create_random_railways(project_root) + rail_generator, schedule_generator = create_random_railways(1.1) # Create the Flatland environment environments = [RailEnv(width=flags.grid_width, height=flags.grid_height, number_of_agents=flags.num_agents, @@ -99,7 +102,7 @@ # After training we want to render the results so we also load a renderer # Add some variables to keep track of the progress -current_score = current_steps = current_collisions = current_done = mean_score = mean_steps = mean_collisions = mean_done = current_taken = mean_taken = 0 +current_score = current_steps = current_collisions = current_done = mean_score = mean_steps = mean_collisions = mean_done = current_taken = mean_taken = None agent_action_buffer = [] start_time = time.time() @@ -126,7 +129,7 @@ def is_collision(a, i): def get_means(x, y, c, s): - return (x * 3 + c) / 4, (y * (s - 1) + c) / s + return c if x is None else (x * 3 + c) / 4, c if y is None else (y * (s - 1) + c) / s chunk_size = (BATCH_SIZE + 1) // flags.threads @@ -198,7 +201,8 @@ def normalize(observation, target_tensor): agent_action_buffer, done, [[is_collision(a, i) for a in range(agent_count)] for i in range(BATCH_SIZE)], - flags.step_reward) + flags.step_reward, + flags.collision_reward) agent_obs_buffer = agent_obs.clone() for idx, act in enumerate(action_dict): for key, value in act.items(): @@ -225,7 +229,7 @@ def normalize(observation, target_tensor): current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken / BATCH_SIZE / agent_count, episode) current_taken, mean_taken = get_means(current_steps, mean_steps, step, episode) - print(f'\rBatch {episode:>4} - Episode {BATCH_SIZE * episode:>6}' + print(f'\rBatch {episode:>4} - Episode {BATCH_SIZE * episode:>6} - Agents: {agent_count:>3}' f' | Score: {current_score:.4f}, {mean_score:.4f}' f' | Agent-Steps: {current_steps:6.1f}, {mean_steps:6.1f}' f' | Steps Taken: {current_taken:6.1f}, {mean_taken:6.1f}' @@ -233,8 +237,6 @@ def normalize(observation, target_tensor): f' | Epsilon: {eps:.2f}' f' | Episode/s: {BATCH_SIZE * episode / (time.time() - start_time):.4f}s', end='') - if episode % flags.report_interval == 0: - print("") - if flags.train: - agent.save(project_root / 'checkpoints', episode, eps) - # Add stats to the tensorboard summary + print("") + if flags.train: + agent.save(project_root / 'checkpoints', episode, eps) From cd6207ff34943ed2c5ab6115e740196b060684ca Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 13 Jul 2020 09:07:42 +0200 Subject: [PATCH 28/75] fix(model): use dropout instead of bernoulli, init properly --- src/model.py | 61 +++++++++++++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/src/model.py b/src/model.py index 0ce4699..a4f2e86 100644 --- a/src/model.py +++ b/src/model.py @@ -1,6 +1,7 @@ +import math import typing + import numpy as np -import math import torch @@ -26,8 +27,6 @@ def __init__(self, in_features: int, out_features: int, kernel_size=1, bias=True padding=0, dilation=1, function=torch.nn.functional.conv1d, stride=1): super().__init__() self.weight_dropout = weight_dropout - self.in_features = in_features - self.out_features = out_features if in_features % groups != 0: print(f"[ERROR] Unable to get weight for in={in_features},groups={groups}. Make sure they are divisible.") if out_features % groups != 0: @@ -44,22 +43,14 @@ def __init__(self, in_features: int, out_features: int, kernel_size=1, bias=True self._function = function def forward(self, fn_input): - if self.training: - weight = self.weight.bernoulli(p=self.weight_dropout) * self.weight - else: - weight = self.weight - return self._function(fn_input, - weight, + torch.nn.functional.dropout(self.weight, self.weight_dropout, self.training), bias=self.bias, padding=self.padding, dilation=self.dilation, groups=self.groups, stride=self.stride) - def extra_repr(self): - return 'in_features={}, out_features={}'.format(self.in_features, self.out_features) - class SeparableConvolution(torch.nn.Module): def __init__(self, in_features, out_features, kernel_size: typing.Union[int, tuple], @@ -228,7 +219,15 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: out = self.linear1(mish(self.mid_norm(self.linear0(mish(self.init_norm(out)))))) return out - +def init(module: torch.nn.Module): + if hasattr(module, "weight") and hasattr(module.weight, "data"): + if "norm" in module.__class__.__name__.lower() or ( + hasattr(module, "__str__") and "norm" in str(module).lower()): + torch.nn.init.uniform_(module.weight.data, 0.998, 1.002) + else: + torch.nn.init.orthogonal_(module.weight.data) + if hasattr(module, "bias") and hasattr(module.bias, "data"): + torch.nn.init.constant_(module.bias.data, 0) # class QNetwork(torch.nn.Module): # def __init__(self, state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=False, # debug=True): @@ -263,15 +262,6 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: # kernel_size=kernel_size, # squeeze_heads=squeeze_heads)]) # -# def init(module: torch.nn.Module): -# if hasattr(module, "weight") and hasattr(module.weight, "data"): -# if "norm" in module.__class__.__name__.lower() or ( -# hasattr(module, "__str__") and "norm" in str(module).lower()): -# torch.nn.init.uniform_(module.weight.data, 0.998, 1.002) -# else: -# torch.nn.init.orthogonal_(module.weight.data) -# if hasattr(module, "bias") and hasattr(module.bias, "data"): -# torch.nn.init.constant_(module.bias.data, 0) # # net.apply(init) # @@ -291,19 +281,32 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: # out = module(out) # return out +class Residual(torch.nn.Module): + def __init__(self, features): + super(Residual, self).__init__() + self.norm = torch.nn.BatchNorm1d(features) + self.conv = WeightDropConv(features, features) + + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + return self.conv(mish(self.norm(fn_input))) + fn_input + + def QNetwork(state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=False, debug=True): - model = torch.nn.Sequential(torch.nn.Conv1d(2*state_size, 33, 1, groups=11, bias=False), - torch.nn.BatchNorm1d(33), - Mish(), - WeightDropConv(33, 33), - torch.nn.BatchNorm1d(33), + model = torch.nn.Sequential(torch.nn.Conv1d(2 * state_size, 55, 1, groups=11, bias=False), + Residual(55), + torch.nn.BatchNorm1d(55), Mish(), - WeightDropConv(33, action_size, 1)) + WeightDropConv(55, action_size, 1)) if debug: parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, model.parameters())) digits = int(math.log10(parameters)) number_string = " kMGTPEZY"[digits // 3] print(f"[DEBUG/MODEL] Training with {parameters * 10 ** -(digits // 3 * 3):.1f}{number_string} parameters") - return torch.jit.script(model) + model.apply(init) + try: + model = torch.jit.script(model) + except TypeError: + pass + return model From 9fd919f509bde04bc9418802549c2084c3353ee0 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 13 Jul 2020 09:15:49 +0200 Subject: [PATCH 29/75] feat(train): add cli args for width/agents of env --- src/railway_utils.py | 4 ++-- src/train.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/railway_utils.py b/src/railway_utils.py index 9a99b95..1b07525 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -99,5 +99,5 @@ def load_precomputed_railways(project_root, start_index, big=True): # Helper function to generate railways on the fly -def create_random_railways(base=1.1): - return RailGenerator(base=base), ScheduleGenerator(base=base) +def create_random_railways(width, base=1.1): + return RailGenerator(width=width, base=base), ScheduleGenerator(base=base) diff --git a/src/train.py b/src/train.py index 2338d56..d0b7940 100644 --- a/src/train.py +++ b/src/train.py @@ -49,6 +49,9 @@ parser.add_argument("--kernel-size", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--squeeze-heads", type=int, default=4, help="Depth of the observation tree") +parser.add_argument("--environment-width", type=int, default=35, help="Depth of the observation tree") +parser.add_argument("--agent-factor", type=float, default=1.1, help="Depth of the observation tree") + # Training parameters parser.add_argument("--agent-type", default="dqn", choices=["dqn", "ppo"], help="Which type of RL agent to use") parser.add_argument("--num-episodes", type=int, default=10 ** 6, help="Number of episodes to train for") @@ -84,7 +87,7 @@ if flags.load_railways: rail_generator, schedule_generator = load_precomputed_railways(project_root, start) else: - rail_generator, schedule_generator = create_random_railways(1.1) + rail_generator, schedule_generator = create_random_railways(flags.environment_width, flags.agent_factor) # Create the Flatland environment environments = [RailEnv(width=flags.grid_width, height=flags.grid_height, number_of_agents=flags.num_agents, From d284c57a01d250a8eea8f78ab73c90f4fc633121 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 13 Jul 2020 11:30:11 +0200 Subject: [PATCH 30/75] feat(rail-generator): cythonize --- src/cythonize.sh | 5 +- src/generate_railways.py | 20 +-- src/rail_generators.pyx | 285 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 295 insertions(+), 15 deletions(-) create mode 100644 src/rail_generators.pyx diff --git a/src/cythonize.sh b/src/cythonize.sh index 2f38ba2..b7762f5 100644 --- a/src/cythonize.sh +++ b/src/cythonize.sh @@ -1,7 +1,7 @@ function compile { file=${1} cython "$file.pyx" -3 -Wextra -D - cmd="gcc-7 $file.c `python3-config --cflags --ldflags --includes --libs` -fno-lto -pthread -fPIC -fwrapv -pipe -march=native -mtune=native -Ofast -msse2 -msse4.2 -shared -o $file.so" + cmd="gcc-7 $file.c `python3-config --cflags --ldflags --includes --libs` -I`python -c 'import numpy, sys; sys.stdout.write(numpy.get_include()); sys.stdout.flush()'` -fno-lto -pthread -fPIC -fwrapv -pipe -march=native -mtune=native -Ofast -msse2 -msse4.2 -shared -o $file.so" echo "Executing $cmd" $cmd echo "Testing compilation.." @@ -10,4 +10,5 @@ function compile { } compile observation_utils -compile rail_env \ No newline at end of file +compile rail_env +compile rail_generators \ No newline at end of file diff --git a/src/generate_railways.py b/src/generate_railways.py index 9df58fe..a535155 100755 --- a/src/generate_railways.py +++ b/src/generate_railways.py @@ -14,23 +14,17 @@ project_root = Path(__file__).resolve().parent.parent parser = argparse.ArgumentParser(description="Train an agent in the flatland environment") -parser.add_argument("--agents", type=int, default=3, help="Number of episodes to train for") -parser.add_argument("--cities", type=int, default=3, help="Number of episodes to train for") parser.add_argument("--width", type=int, default=35, help="Decay factor for epsilon-greedy exploration") -parser.add_argument("--height", type=int, default=None, help="Decay factor for epsilon-greedy exploration") - flags = parser.parse_args() width = flags.width -height = width if flags.height is None else flags.height -n_agents = flags.agents -rail_generator, schedule_generator = create_random_railways(project_root, flags.cities) +rail_generator, schedule_generator = create_random_railways(flags.width) # Load in any existing railways for this map size so we don't overwrite them try: - with open(project_root / f'railroads/rail_networks_{n_agents}x{width}x{height}.pkl', 'rb') as file: + with open(project_root / f'railroads/rail_networks_{width}.pkl', 'rb') as file: rail_networks = pickle.load(file) - with open(project_root / f'railroads/schedules_{n_agents}x{width}x{height}.pkl', 'rb') as file: + with open(project_root / f'railroads/schedules_{width}.pkl', 'rb') as file: schedules = pickle.load(file) print(f"Loading {len(rail_networks)} railways...") except: @@ -39,8 +33,8 @@ def do(schedules: list, rail_networks: list): for _ in range(100): - map, info = rail_generator(width, height, n_agents, num_resets=0, np_random=np.random) - schedule = schedule_generator(map, n_agents, info['agents_hints'], num_resets=0, np_random=np.random) + map, info = rail_generator(width, 1, 1, num_resets=0, np_random=np.random) + schedule = schedule_generator(map, 1, info['agents_hints'], num_resets=0, np_random=np.random) rail_networks.append((map, info)) schedules.append(schedule) return @@ -52,9 +46,9 @@ def do(schedules: list, rail_networks: list): # Generate 10000 random railways in 100 batches of 100 for _ in tqdm(range(500), ncols=150, leave=False): do(schedules, rail_networks) - with open(project_root / f'railroads/rail_networks_{n_agents}x{width}x{height}.pkl', 'wb') as file: + with open(project_root / f'railroads/rail_networks_{width}.pkl', 'wb') as file: pickle.dump(schedules, file, protocol=4) - with open(project_root / f'railroads/schedules_{n_agents}x{width}x{height}.pkl', 'wb') as file: + with open(project_root / f'railroads/schedules_{width}.pkl', 'wb') as file: pickle.dump(rail_networks, file, protocol=4) print(f"Saved {len(shared_rail_networks)} railways") diff --git a/src/rail_generators.pyx b/src/rail_generators.pyx new file mode 100644 index 0000000..5f6a98b --- /dev/null +++ b/src/rail_generators.pyx @@ -0,0 +1,285 @@ +"""Rail generators (infrastructure manager, "Infrastrukturbetreiber").""" +import warnings +from typing import Callable, Tuple, Optional, Dict + +cimport numpy as cnp +import numpy as np +from flatland.core.grid.grid4_utils import direction_to_point +from flatland.core.grid.grid_utils import Vec2dOperations +from flatland.core.grid.rail_env_grid import RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \ + fix_inner_nodes, align_cell_to_city +from numpy.random.mtrand import RandomState + +RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] +RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] +cnp.import_array() +# CONSTANTS +cdef bint grid_mode = False +cdef int max_rails_between_cities = 3 +cdef int max_rails_in_city = 4 + +cdef int NORTH = 0 +cdef int EAST = 1 +cdef int SOUTH = 2 +cdef int WEST = 3 + +def generator(int width, int num_agents): + cdef int city_padding = 2 + cdef int max_num_cities = max(2, width ** 2 // 300) + + rail_trans = RailEnvTransitions() + grid_map = GridTransitionMap(width=width, height=width, transitions=rail_trans) + # We compute the city radius by the given max number of rails it can contain. + # The radius is equal to the number of tracks divided by 2 + # We add 2 cells to avoid that track lenght is to short + # We use ceil if we get uneven numbers of city radius. This is to guarantee that all rails fit within the city. + cdef int city_radius = ((max_rails_in_city + 1) // 2) + city_padding + cdef cnp.ndarray vector_field = np.zeros(shape=(width, width)) - 1. + + + # Calculate the max number of cities allowed + # and reduce the number of cities to build to avoid problems + cdef int max_feasible_cities = min(max_num_cities, ((width - 2) // (2 * (city_radius + 1))) ** 2) + + cdef bint too_close + cdef int col, tries, row + cdef tuple city_pos + cdef list city_positions = [] + cdef int min_distance = (2 * (city_radius + 1) + 1) + cdef int city_idx + + for city_idx in range(max_feasible_cities): + too_close = True + tries = 0 + + while too_close: + row = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1)) + col = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1)) + too_close = False + # Check distance to cities + for city_pos in city_positions: + if np.abs(row - city_pos[0]) < min_distance and np.abs(col - city_pos[1]) < min_distance: + too_close = True + break + + if not too_close: + city_positions.append((row, col)) + + tries += 1 + if tries > 200: + warnings.warn("Could not set all required cities!") + break + + cdef list inner_connection_points = [] + cdef list outer_connection_points = [] + cdef list city_orientations = [] + cdef list city_cells = [] + cdef list neighb_dist, connection_sides_idx, connection_points_coordinates_outer + cdef list connection_points_coordinates_inner, _city_cells + cdef int current_closest_direction, idx, nr_of_connection_points + cdef int number_of_out_rails, start_idx, direction, connection_idx + cdef tuple neighbour_city, cell + cdef tuple tmp_coordinates = tuple() + cdef out_tmp_coordinates = tuple() + cdef cnp.ndarray connections_per_direction, connection_slots, x_range, y_range, x_values, y_values, inner_point_offset + for city_pos in city_positions: + + # Chose the directions where close cities are situated + neighb_dist = [] + for neighbour_city in city_positions: + neighb_dist.append(Vec2dOperations.get_manhattan_distance(city_pos, neighbour_city)) + closest_neighb_idx = np.argsort(neighb_dist) + + # Store the directions to these neighbours and orient city to face closest neighbour + connection_sides_idx = [] + idx = 1 + current_closest_direction = direction_to_point(city_pos, city_positions[closest_neighb_idx[idx]]) + connection_sides_idx.append(current_closest_direction) + connection_sides_idx.append((current_closest_direction + 2) % 4) + city_orientations.append(current_closest_direction) + x_range = np.arange(city_pos[0] - city_radius, city_pos[0] + city_radius + 1) + y_range = np.arange(city_pos[1] - city_radius, city_pos[1] + city_radius + 1) + x_values = np.repeat(x_range, len(y_range)) + y_values = np.tile(y_range, len(x_range)) + _city_cells = list(zip(x_values, y_values)) + for cell in _city_cells: + vector_field[cell] = align_cell_to_city(city_pos, city_orientations[-1], cell) + city_cells.extend(_city_cells) + # set the number of tracks within a city, at least 2 tracks per city + connections_per_direction = np.zeros(4, dtype=int) + nr_of_connection_points = np.random.randint(2, max_rails_in_city + 1) + for idx in connection_sides_idx: + connections_per_direction[idx] = nr_of_connection_points + connection_points_coordinates_inner = [[] for _ in range(4)] + connection_points_coordinates_outer = [[] for _ in range(4)] + number_of_out_rails = np.random.randint(1, min(max_rails_in_city, nr_of_connection_points) + 1) + start_idx = int((nr_of_connection_points - number_of_out_rails) / 2) + for direction in range(4): + connection_slots = np.arange(nr_of_connection_points) - start_idx + # Offset the rails away from the center of the city + offset_distances = np.arange(nr_of_connection_points) - int(nr_of_connection_points / 2) + # The clipping helps ofsetting one side more than the other to avoid switches at same locations + # The magic number plus one is added such that all points have at least one offset + inner_point_offset = np.abs(offset_distances) + np.clip(offset_distances, 0, 1) + 1 + for connection_idx in range(connections_per_direction[direction]): + if direction == 0: + tmp_coordinates = ( + city_pos[0] - city_radius + inner_point_offset[connection_idx], + city_pos[1] + connection_slots[connection_idx]) + out_tmp_coordinates = ( + city_pos[0] - city_radius, city_pos[1] + connection_slots[connection_idx]) + if direction == 1: + tmp_coordinates = ( + city_pos[0] + connection_slots[connection_idx], + city_pos[1] + city_radius - inner_point_offset[connection_idx]) + out_tmp_coordinates = ( + city_pos[0] + connection_slots[connection_idx], city_pos[1] + city_radius) + if direction == 2: + tmp_coordinates = ( + city_pos[0] + city_radius - inner_point_offset[connection_idx], + city_pos[1] + connection_slots[connection_idx]) + out_tmp_coordinates = ( + city_pos[0] + city_radius, city_pos[1] + connection_slots[connection_idx]) + if direction == 3: + tmp_coordinates = ( + city_pos[0] + connection_slots[connection_idx], + city_pos[1] - city_radius + inner_point_offset[connection_idx]) + out_tmp_coordinates = ( + city_pos[0] + connection_slots[connection_idx], city_pos[1] - city_radius) + connection_points_coordinates_inner[direction].append(tmp_coordinates) + if connection_idx in range(start_idx, start_idx + number_of_out_rails): + connection_points_coordinates_outer[direction].append(out_tmp_coordinates) + + inner_connection_points.append(connection_points_coordinates_inner) + outer_connection_points.append(connection_points_coordinates_outer) + + cdef list inter_city_lines = [] + cdef list city_distances, closest_neighbours + cdef int current_city_idx, direction_to_neighbour, out_direction, neighbour_idx + + for current_city_idx in np.arange(len(city_positions)): + city_distances = [] + closest_neighbours = [None for _ in range(4)] + + # compute distance to all other cities + for city_idx in range(len(city_positions)): + city_distances.append( + Vec2dOperations.get_manhattan_distance(city_positions[current_city_idx], city_positions[city_idx])) + sorted_neighbours = np.argsort(city_distances) + + for neighbour in sorted_neighbours[1:]: # do not include city itself + direction_to_neighbour = direction_to_point(city_positions[current_city_idx], city_positions[neighbour]) + if closest_neighbours[direction_to_neighbour] is None: + closest_neighbours[direction_to_neighbour] = neighbour + + # early return once all 4 directions have a closest neighbour + if None not in closest_neighbours: + break + for out_direction in range(4): + if closest_neighbours[out_direction] is not None: + neighbour_idx = closest_neighbours[out_direction] + elif closest_neighbours[(out_direction - 1) % 4] is not None: + neighbour_idx = closest_neighbours[(out_direction - 1) % 4] # counter-clockwise + elif closest_neighbours[(out_direction + 1) % 4] is not None: + neighbour_idx = closest_neighbours[(out_direction + 1) % 4] # clockwise + elif closest_neighbours[(out_direction + 2) % 4] is not None: + neighbour_idx = closest_neighbours[(out_direction + 2) % 4] + + for city_out_connection_point in outer_connection_points[current_city_idx][out_direction]: + + min_connection_dist = np.inf + neighbour_connection_point = None + for direction in range(4): + current_points = outer_connection_points[neighbour_idx][direction] + for tmp_in_connection_point in current_points: + tmp_dist = Vec2dOperations.get_manhattan_distance(city_out_connection_point, + tmp_in_connection_point) + if tmp_dist < min_connection_dist: + min_connection_dist = tmp_dist + neighbour_connection_point = tmp_in_connection_point + + new_line = connect_rail_in_grid_map(grid_map, city_out_connection_point, neighbour_connection_point, + rail_trans, flip_start_node_trans=False, + flip_end_node_trans=False, respect_transition_validity=False, + avoid_rail=True, + forbidden_cells=city_cells) + inter_city_lines.extend(new_line) + + # Build inner cities + cdef int i, current_city, opposite_boarder + cdef int boarder = 0 + cdef int track_id, track_nbr + cdef list free_rails = [[] for _ in range(len(city_positions))] + for current_city in range(len(city_positions)): + + # This part only works if we have keep same number of connection points for both directions + # Also only works with two connection direction at each city + for i in range(4): + if len(inner_connection_points[current_city][i]) > 0: + boarder = i + break + + opposite_boarder = (boarder + 2) % 4 + nr_of_connection_points = len(inner_connection_points[current_city][boarder]) + number_of_out_rails = len(outer_connection_points[current_city][boarder]) + start_idx = (nr_of_connection_points - number_of_out_rails) // 2 + # Connect parallel tracks + for track_id in range(nr_of_connection_points): + source = inner_connection_points[current_city][boarder][track_id] + target = inner_connection_points[current_city][opposite_boarder][track_id] + current_track = connect_straight_line_in_grid_map(grid_map, source, target, rail_trans) + free_rails[current_city].append(current_track) + + for track_id in range(nr_of_connection_points): + source = inner_connection_points[current_city][boarder][track_id] + target = inner_connection_points[current_city][opposite_boarder][track_id] + + # Connect parallel tracks with each other + fix_inner_nodes( + grid_map, source, rail_trans) + fix_inner_nodes( + grid_map, target, rail_trans) + + # Connect outer tracks to inner tracks + if start_idx <= track_id < start_idx + number_of_out_rails: + source_outer = outer_connection_points[current_city][boarder][track_id - start_idx] + target_outer = outer_connection_points[current_city][opposite_boarder][track_id - start_idx] + connect_straight_line_in_grid_map(grid_map, source, source_outer, rail_trans) + connect_straight_line_in_grid_map(grid_map, target, target_outer, rail_trans) + + # Populate cities + cdef int num_cities = len(city_positions) + cdef list train_stations = [[] for _ in range(num_cities)] + for current_city in range(len(city_positions)): + for track_nbr in range(len(free_rails[current_city])): + possible_location = free_rails[current_city][track_nbr][ + int(len(free_rails[current_city][track_nbr]) / 2)] + train_stations[current_city].append((possible_location, track_nbr)) + + # Fix all transition elements + + cdef cnp.ndarray rails_to_fix = np.zeros(3 * grid_map.height * grid_map.width * 2, dtype='int') + cdef int rails_to_fix_cnt = 0 + cdef list cells_to_fix = city_cells + inter_city_lines + cdef bint cell_valid + for cell in cells_to_fix: + cell_valid = grid_map.cell_neighbours_valid(cell, True) + + if not cell_valid: + rails_to_fix[3 * rails_to_fix_cnt] = cell[0] + rails_to_fix[3 * rails_to_fix_cnt + 1] = cell[1] + rails_to_fix[3 * rails_to_fix_cnt + 2] = vector_field[cell] + + rails_to_fix_cnt += 1 + # Fix all other cells + for idx in range(rails_to_fix_cnt): + grid_map.fix_transitions((rails_to_fix[3 * idx], rails_to_fix[3 * idx + 1]), rails_to_fix[3 * idx + 2]) + + return grid_map, {'agents_hints': { + 'num_agents': num_agents, + 'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + }} From e72fe0d16a7abc04557b9d62f69aa308c60effd9 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 13 Jul 2020 11:30:45 +0200 Subject: [PATCH 31/75] style(train): remove unused commandline arguments --- src/train.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/train.py b/src/train.py index d0b7940..1009620 100644 --- a/src/train.py +++ b/src/train.py @@ -34,15 +34,12 @@ parser.add_argument("--train", type=boolean, default=True, help="Whether to train the model or just evaluate it") parser.add_argument("--load-model", default=False, action='store_true', help="Whether to load the model from the last checkpoint") -parser.add_argument("--load-railways", type=boolean, default=False, +parser.add_argument("--load-railways", type=boolean, default=True, help="Whether to load in pre-generated railway networks") parser.add_argument("--report-interval", type=int, default=100, help="Iterations between reports") parser.add_argument("--render-interval", type=int, default=0, help="Iterations between renders") # Environment parameters -parser.add_argument("--grid-width", type=int, default=50, help="Number of columns in the environment grid") -parser.add_argument("--grid-height", type=int, default=50, help="Number of rows in the environment grid") -parser.add_argument("--num-agents", type=int, default=5, help="Number of agents in each episode") parser.add_argument("--tree-depth", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--model-depth", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--hidden-factor", type=int, default=5, help="Depth of the observation tree") @@ -53,7 +50,6 @@ parser.add_argument("--agent-factor", type=float, default=1.1, help="Depth of the observation tree") # Training parameters -parser.add_argument("--agent-type", default="dqn", choices=["dqn", "ppo"], help="Which type of RL agent to use") parser.add_argument("--num-episodes", type=int, default=10 ** 6, help="Number of episodes to train for") parser.add_argument("--epsilon-decay", type=float, default=0, help="Decay factor for epsilon-greedy exploration") parser.add_argument("--step-reward", type=float, default=-1e-2, help="Depth of the observation tree") @@ -90,7 +86,7 @@ rail_generator, schedule_generator = create_random_railways(flags.environment_width, flags.agent_factor) # Create the Flatland environment -environments = [RailEnv(width=flags.grid_width, height=flags.grid_height, number_of_agents=flags.num_agents, +environments = [RailEnv(width=40, height=40, number_of_agents=1, rail_generator=rail_generator, schedule_generator=schedule_generator, malfunction_generator_and_process_data=malfunction_from_params( From 5537b959dad093688e74b413b0e2449805247307 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 13 Jul 2020 11:31:22 +0200 Subject: [PATCH 32/75] feat(railway-utils): use custom generator --- src/railway_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/railway_utils.py b/src/railway_utils.py index 1b07525..577da54 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -1,13 +1,15 @@ import os import pickle -from flatland.envs.rail_generators import sparse_rail_generator -from flatland.envs.schedule_generators import sparse_schedule_generator import numpy as np +from flatland.envs.schedule_generators import sparse_schedule_generator + try: from .agent import BATCH_SIZE + from .rail_generators import generator as rail_generator except: from agent import BATCH_SIZE + from rail_generators import generator as rail_generator class Generator: @@ -47,8 +49,6 @@ def __call__(self, *args, **kwargs): class RailGenerator: def __init__(self, width=35, base=1.5): - self.rail_generator = sparse_rail_generator(grid_mode=False, max_num_cities=3, max_rails_between_cities=2, - max_rails_in_city=3) self.sub_idx = 0 self.top_idx = 0 self.width = width @@ -59,7 +59,7 @@ def __next__(self): if self.sub_idx == BATCH_SIZE: self.sub_idx = 0 self.top_idx += 1 - return self.rail_generator(self.width, self.width, int(2 * self.base ** self.top_idx), np_random=np.random) + return rail_generator(self.width, int(2 * self.base ** self.top_idx)) def __call__(self, *args, **kwargs): return next(self) @@ -87,13 +87,13 @@ def __call__(self, rail, _, hints, *args, **kwargs): def load_precomputed_railways(project_root, start_index, big=True): prefix = os.path.join(project_root, 'railroads') if big: - suffix = f'_50x35x20.pkl' + suffix = f'_45x90x90.pkl' else: suffix = f'_3x30x30.pkl' sched = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) rail = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) - if big: - sched, rail = rail, sched + # if big: + # sched, rail = rail, sched print(f"Working on {len(rail)} tracks") return rail, sched From 413d571cb455d7341db671d0c221219b2bc53213 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 13 Jul 2020 12:54:57 +0200 Subject: [PATCH 33/75] perf(generate-railways): undo cythonizing --- src/agent.py | 5 +++-- src/generate_railways.py | 1 + src/railway_utils.py | 12 ++++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/agent.py b/src/agent.py index f060c76..e6ef79d 100644 --- a/src/agent.py +++ b/src/agent.py @@ -17,11 +17,12 @@ GAMMA = 0.998 TAU = 1e-3 CLIP_FACTOR = 0.2 -LR = 2e-4 +LR = 4e-5 UPDATE_EVERY = 1 DOUBLE_DQN = False +CUDA = False -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda:0" if CUDA and torch.cuda.is_available() else "cpu") diff --git a/src/generate_railways.py b/src/generate_railways.py index a535155..5ad11a3 100755 --- a/src/generate_railways.py +++ b/src/generate_railways.py @@ -6,6 +6,7 @@ import numpy as np from tqdm import tqdm + try: from .railway_utils import create_random_railways except: diff --git a/src/railway_utils.py b/src/railway_utils.py index 577da54..6343e7a 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -2,14 +2,14 @@ import pickle import numpy as np + +from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator try: from .agent import BATCH_SIZE - from .rail_generators import generator as rail_generator except: from agent import BATCH_SIZE - from rail_generators import generator as rail_generator class Generator: @@ -49,6 +49,10 @@ def __call__(self, *args, **kwargs): class RailGenerator: def __init__(self, width=35, base=1.5): + self.rail_generator = sparse_rail_generator(grid_mode=False, + max_num_cities=max(2, width ** 2 // 300), + max_rails_between_cities=2, + max_rails_in_city=3) self.sub_idx = 0 self.top_idx = 0 self.width = width @@ -59,7 +63,7 @@ def __next__(self): if self.sub_idx == BATCH_SIZE: self.sub_idx = 0 self.top_idx += 1 - return rail_generator(self.width, int(2 * self.base ** self.top_idx)) + return self.rail_generator(self.width, self.width, int(2 * self.base ** self.top_idx), np_random=np.random) def __call__(self, *args, **kwargs): return next(self) @@ -92,7 +96,7 @@ def load_precomputed_railways(project_root, start_index, big=True): suffix = f'_3x30x30.pkl' sched = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) rail = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) - # if big: + #if big: # sched, rail = rail, sched print(f"Working on {len(rail)} tracks") return rail, sched From a953f13685f0be35f82182dec7af656165aaf816 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 13 Jul 2020 23:41:21 +0200 Subject: [PATCH 34/75] style: remove legacy code --- src/model.py | 163 ++++----------------------------------------------- src/train.py | 4 +- 2 files changed, 15 insertions(+), 152 deletions(-) diff --git a/src/model.py b/src/model.py index a4f2e86..b780255 100644 --- a/src/model.py +++ b/src/model.py @@ -86,97 +86,6 @@ def __repr__(self): return self.str -def try_norm(tensor, norm): - if norm is not None: - tensor = mish(norm(tensor)) - return tensor - - -def make_excite(conv, rank_norm, ranker, linear_norm, linear0, excite_norm, linear1): - @torch.jit.script - def excite(out): - batch = out.size(0) - - exc = conv(out) - - squeeze_heads = exc.size(1) - - exc = torch.nn.functional.softmax(exc, 2) - exc = exc.unsqueeze(-1).transpose(1, -1) - exc = (out.unsqueeze(-1) * exc).sum(2) - - # Rank experts (heads) - hds = exc.view(batch, squeeze_heads, -1) - exc = rank_norm(hds) - exc = ranker(mish(exc)) - exc = exc.softmax(-1) - exc = exc.bmm(hds) - exc = exc.view(batch, -1, 1) - - # Fully-connected block - nrm = linear_norm(exc).squeeze(-1) - nrm = linear0(nrm).unsqueeze(-1) - nrm = excite_norm(nrm) - act = mish(nrm.squeeze(-1)) - exc = linear1(act).tanh() - exc = exc.unsqueeze(-1) - exc = exc.expand_as(out) - - # Merge - out = out * exc - return out - - return excite - - -class Block(torch.nn.Module): - def __init__(self, hidden_size, output_size, bias=False, cat=True, init_norm=False, out_norm=True, - kernel_size=7, squeeze_heads=4): - super().__init__() - self.residual = hidden_size == output_size - self.cat = cat - - self.init_norm = torch.nn.BatchNorm1d(hidden_size) if init_norm else None - self.linr = SeparableConvolution(hidden_size, output_size, kernel_size, padding=kernel_size // 2, bias=bias) - self.out_norm = torch.nn.BatchNorm1d(output_size) if out_norm else None - - self.use_squeeze_attention = squeeze_heads > 0 - - if self.use_squeeze_attention: - self.squeeze_heads = squeeze_heads - self.exc_input_norm = torch.nn.BatchNorm1d(squeeze_heads) - self.expert_ranker = torch.nn.Linear(output_size, squeeze_heads, False) - self.excitation_conv = SeparableConvolution(output_size, squeeze_heads, kernel_size, - padding=kernel_size // 2) - self.linear_in_norm = torch.nn.BatchNorm1d(output_size * squeeze_heads) - self.linear0 = torch.nn.Linear(output_size * squeeze_heads, output_size, False) - self.exc_norm = torch.nn.BatchNorm1d(output_size) - self.linear1 = torch.nn.Linear(output_size, output_size) - self.excite = make_excite(self.excitation_conv, - self.exc_input_norm, - self.expert_ranker, - self.linear_in_norm, - self.linear0, - self.exc_norm, - self.linear1) - - def forward(self, fn_input: torch.Tensor) -> torch.Tensor: - fn_input = try_norm(fn_input, self.init_norm) - out = self.linr(fn_input) - out = try_norm(out, self.out_norm) - - if self.use_squeeze_attention: - out = self.excite(out) - - if self.cat: - return torch.cat([out, fn_input], 1) - - if self.residual: - return out + fn_input - - return out - - class BasicBlock(torch.nn.Module): def __init__(self, in_features, out_features, stride, init_norm=False): super(BasicBlock, self).__init__() @@ -219,6 +128,7 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: out = self.linear1(mish(self.mid_norm(self.linear0(mish(self.init_norm(out)))))) return out + def init(module: torch.nn.Module): if hasattr(module, "weight") and hasattr(module.weight, "data"): if "norm" in module.__class__.__name__.lower() or ( @@ -228,76 +138,29 @@ def init(module: torch.nn.Module): torch.nn.init.orthogonal_(module.weight.data) if hasattr(module, "bias") and hasattr(module.bias, "data"): torch.nn.init.constant_(module.bias.data, 0) -# class QNetwork(torch.nn.Module): -# def __init__(self, state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=False, -# debug=True): -# """ -# 11 input features, state_size//11 = item_count -# :param state_size: -# :param action_size: -# :param hidden_factor: -# :param depth: -# :return: -# """ -# super(QNetwork, self).__init__() -# observations = state_size // 11 -# if debug: -# print(f"[DEBUG/MODEL] Using {observations} observations as input") -# -# out_features = hidden_factor * 11 -# -# net = torch.nn.ModuleList([torch.nn.Conv1d(state_size, out_features, 1), -# *[Block(out_features + out_features * i * cat, -# out_features, -# cat=cat, -# init_norm=not i, -# kernel_size=kernel_size, -# squeeze_heads=squeeze_heads) -# for i in range(depth)], -# Block(out_features + out_features * depth * cat, action_size, -# bias=True, -# cat=False, -# out_norm=False, -# init_norm=False, -# kernel_size=kernel_size, -# squeeze_heads=squeeze_heads)]) -# -# -# net.apply(init) -# -# if debug: -# parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, net.parameters())) -# digits = int(math.log10(parameters)) -# number_string = " kMGTPEZY"[digits // 3] -# -# print( -# f"[DEBUG/MODEL] Training with {parameters * 10 ** -(digits // 3 * 3):.1f}{number_string} parameters") -# -# self.net = net -# -# def forward(self, fn_input: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]: -# out = fn_input -# for module in self.net: -# out = module(out) -# return out class Residual(torch.nn.Module): def __init__(self, features): super(Residual, self).__init__() self.norm = torch.nn.BatchNorm1d(features) - self.conv = WeightDropConv(features, features) + self.conv = WeightDropConv(features, 2 * features) def forward(self, fn_input: torch.Tensor) -> torch.Tensor: - return self.conv(mish(self.norm(fn_input))) + fn_input + out, exc = self.conv(mish(self.norm(fn_input))).chunk(2, 1) + exc = exc.mean(dim=1, keepdim=True).tanh() + fn_input = fn_input * exc + out = -out * exc + return fn_input + out -def QNetwork(state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=False, +def QNetwork(state_size, action_size, hidden_factor=16, depth=4, kernel_size=7, squeeze_heads=4, cat=False, debug=True): - model = torch.nn.Sequential(torch.nn.Conv1d(2 * state_size, 55, 1, groups=11, bias=False), - Residual(55), - torch.nn.BatchNorm1d(55), + model = torch.nn.Sequential(torch.nn.Conv1d(2 * state_size, 11 * hidden_factor, 1, groups=11, bias=False), + Residual(11 * hidden_factor), + torch.nn.BatchNorm1d(11 * hidden_factor), Mish(), - WeightDropConv(55, action_size, 1)) + WeightDropConv(11 * hidden_factor, action_size, 1)) + print(model) if debug: parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, model.parameters())) digits = int(math.log10(parameters)) diff --git a/src/train.py b/src/train.py index 1009620..727fd31 100644 --- a/src/train.py +++ b/src/train.py @@ -40,9 +40,9 @@ parser.add_argument("--render-interval", type=int, default=0, help="Iterations between renders") # Environment parameters -parser.add_argument("--tree-depth", type=int, default=1, help="Depth of the observation tree") +parser.add_argument("--tree-depth", type=int, default=3, help="Depth of the observation tree") parser.add_argument("--model-depth", type=int, default=1, help="Depth of the observation tree") -parser.add_argument("--hidden-factor", type=int, default=5, help="Depth of the observation tree") +parser.add_argument("--hidden-factor", type=int, default=16, help="Depth of the observation tree") parser.add_argument("--kernel-size", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--squeeze-heads", type=int, default=4, help="Depth of the observation tree") From 1cbabcebec54af7607537a7a395d06b2cbbf1e83 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 13 Jul 2020 23:43:03 +0200 Subject: [PATCH 35/75] fix(model): message box for agents (not features) --- src/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.py b/src/model.py index b780255..cd3d9d3 100644 --- a/src/model.py +++ b/src/model.py @@ -147,7 +147,7 @@ def __init__(self, features): def forward(self, fn_input: torch.Tensor) -> torch.Tensor: out, exc = self.conv(mish(self.norm(fn_input))).chunk(2, 1) - exc = exc.mean(dim=1, keepdim=True).tanh() + exc = exc.mean(dim=-1, keepdim=True).tanh() fn_input = fn_input * exc out = -out * exc return fn_input + out From e9cf607f517b8b4ec80b8bfaf46204a97b72f845 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 17 Jul 2020 04:24:11 +0200 Subject: [PATCH 36/75] feat: add global, local env --- checkpoints/ppo/README.md | 1 - src/agent.py | 9 ++- src/model.py | 54 ++++++++++----- src/observation_utils.pyx | 137 +++++++++++++++++++++++++++++++++++++- src/railway_utils.py | 6 +- src/train.py | 71 +++++++++++++------- 6 files changed, 231 insertions(+), 47 deletions(-) delete mode 100644 checkpoints/ppo/README.md diff --git a/checkpoints/ppo/README.md b/checkpoints/ppo/README.md deleted file mode 100644 index 700b9ec..0000000 --- a/checkpoints/ppo/README.md +++ /dev/null @@ -1 +0,0 @@ -PPO checkpoints will be saved here diff --git a/src/agent.py b/src/agent.py index e6ef79d..62221ff 100644 --- a/src/agent.py +++ b/src/agent.py @@ -13,7 +13,7 @@ import os BUFFER_SIZE = 500_000 -BATCH_SIZE = 256 +BATCH_SIZE = 32 GAMMA = 0.998 TAU = 1e-3 CLIP_FACTOR = 0.2 @@ -50,6 +50,13 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s kernel_size, squeeze_heads, debug=False).to(device) + try: + self.policy = torch.jit.script(self.policy) + self.old_policy = torch.jit.script(self.old_policy) + except: + import traceback + traceback.print_exc() + print("NO JIT") self.old_policy.load_state_dict(self.policy.state_dict()) self.optimizer = Optimizer(self.policy.parameters(), lr=LR, weight_decay=1e-2) diff --git a/src/model.py b/src/model.py index cd3d9d3..073eac2 100644 --- a/src/model.py +++ b/src/model.py @@ -8,7 +8,9 @@ @torch.jit.script def mish(fn_input: torch.Tensor) -> torch.Tensor: return fn_input * torch.tanh(torch.nn.functional.softplus(fn_input)) - +@torch.jit.script +def nothing(x): + return x class Mish(torch.nn.Module): def forward(self, fn_input: torch.Tensor) -> torch.Tensor: @@ -23,7 +25,8 @@ class WeightDropConv(torch.nn.Module): weight_dropout (float): The probability a weight will be dropped. """ - def __init__(self, in_features: int, out_features: int, kernel_size=1, bias=True, weight_dropout=0.1, groups=1, + def __init__(self, in_features: int, out_features: int, kernel_size: typing.Union[int, tuple] = 1, bias=True, + weight_dropout=0.1, groups=1, padding=0, dilation=1, function=torch.nn.functional.conv1d, stride=1): super().__init__() self.weight_dropout = weight_dropout @@ -31,7 +34,10 @@ def __init__(self, in_features: int, out_features: int, kernel_size=1, bias=True print(f"[ERROR] Unable to get weight for in={in_features},groups={groups}. Make sure they are divisible.") if out_features % groups != 0: print(f"[ERROR] Unable to get weight for out={out_features},groups={groups}. Make sure they are divisible.") - self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features // groups, kernel_size)) + if isinstance(kernel_size, int): + self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features // groups, kernel_size)) + else: + self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features // groups, *kernel_size)) if bias: self.bias = torch.nn.Parameter(torch.Tensor(out_features)) else: @@ -43,8 +49,11 @@ def __init__(self, in_features: int, out_features: int, kernel_size=1, bias=True self._function = function def forward(self, fn_input): + drop = torch.nn.functional.dropout(self.weight, self.weight_dropout, self.training) + if drop.dtype != self.weight.dtype: + drop = drop.to(self.weight.dtype) return self._function(fn_input, - torch.nn.functional.dropout(self.weight, self.weight_dropout, self.training), + drop, bias=self.bias, padding=self.padding, dilation=self.dilation, @@ -57,9 +66,11 @@ def __init__(self, in_features, out_features, kernel_size: typing.Union[int, tup padding: typing.Union[int, tuple] = 0, dilation: typing.Union[int, tuple] = 1, bias=False, dim=1, stride=1): super(SeparableConvolution, self).__init__() - self.depthwise = kernel_size > 1 + self.depthwise = kernel_size > 1 if isinstance(kernel_size, int) else all(k>1 for k in kernel_size) function = getattr(torch.nn.functional, f'conv{dim}d') norm = getattr(torch.nn, f'BatchNorm{dim}d') + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * dim if self.depthwise: self.depthwise_conv = WeightDropConv(in_features, in_features, kernel_size, @@ -70,7 +81,10 @@ def __init__(self, in_features, out_features, kernel_size: typing.Union[int, tup function=function, stride=stride) self.mid_norm = norm(in_features) - self.pointwise_conv = WeightDropConv(in_features, out_features, 1, bias=bias, function=function) + else: + self.depthwise_conv = nothing + self.mid_norm = nothing + self.pointwise_conv = WeightDropConv(in_features, out_features, (1,)*dim , bias=bias, function=function) self.str = (f'SeparableConvolution({in_features}, {out_features}, {kernel_size}, ' + f'dilation={dilation}, padding={padding})') @@ -89,13 +103,15 @@ def __repr__(self): class BasicBlock(torch.nn.Module): def __init__(self, in_features, out_features, stride, init_norm=False): super(BasicBlock, self).__init__() - self.init_norm = torch.nn.BatchNorm2d(out_features) if init_norm else None - self.init_conv = SeparableConvolution(in_features, out_features, 3, 1, stride=stride, dim=2) - self.mid_norm = torch.nn.BatchNorm2d(out_features) - self.end_conv = SeparableConvolution(in_features, out_features, 3, 1, dim=2) + self.init_norm = torch.nn.BatchNorm3d(in_features) if init_norm else None + self.init_conv = SeparableConvolution(in_features, out_features, (3, 3, 1), (1, 1, 0), + stride=(stride, stride, 1), dim=3) + self.mid_norm = torch.nn.BatchNorm3d(out_features) + self.end_conv = SeparableConvolution(out_features, out_features, (3, 3, 1), (1, 1, 0), dim=3) self.shortcut = (None if stride == 1 and in_features == out_features - else SeparableConvolution(in_features, out_features, 3, 1, stride=stride, dim=2)) + else SeparableConvolution(in_features, out_features, (3, 3, 1), (1, 1, 0), + stride=(stride, stride, 1), dim=3)) def forward(self, fn_input: torch.Tensor) -> torch.Tensor: out = self.init_conv(fn_input if self.init_norm is None else mish(self.init_norm(fn_input))) @@ -108,17 +124,18 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: class ConvNetwork(torch.nn.Module): - def __init__(self, state_size, action_size, hidden_factor=15, depth=4, kernel_size=7, squeeze_heads=4, cat=True, + def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, cat=True, debug=True): super(ConvNetwork, self).__init__() - hidden_size = 11 * hidden_factor - self.net = torch.nn.ModuleList([BasicBlock(state_size, hidden_size, 1), - *[BasicBlock(hidden_size, hidden_size, 2 - i % 2, True) - for i in range(depth)]]) + _ = state_size + state_size = 2*21 + self.net = torch.nn.ModuleList([BasicBlock(state_size if not i else hidden_size, hidden_size, 2, True) + for i in range(depth)]) self.init_norm = torch.nn.BatchNorm1d(hidden_size) - self.linear0 = torch.nn.Linear(hidden_size, hidden_size, bias=False) + self.linear0 = torch.nn.Conv1d(hidden_size, hidden_size, 1, bias=False) self.mid_norm = torch.nn.BatchNorm1d(hidden_size) - self.linear1 = torch.nn.Linear(hidden_size, action_size) + self.linear1 = torch.nn.Conv1d(hidden_size, action_size, 1) + print(self) def forward(self, fn_input: torch.Tensor) -> torch.Tensor: out = fn_input @@ -139,6 +156,7 @@ def init(module: torch.nn.Module): if hasattr(module, "bias") and hasattr(module.bias, "data"): torch.nn.init.constant_(module.bias.data, 0) + class Residual(torch.nn.Module): def __init__(self, features): super(Residual, self).__init__() diff --git a/src/observation_utils.pyx b/src/observation_utils.pyx index 333a238..abd7256 100644 --- a/src/observation_utils.pyx +++ b/src/observation_utils.pyx @@ -1,5 +1,6 @@ from collections import defaultdict - +cimport numpy as cnp +import numpy as np import torch from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4_utils import get_new_position @@ -45,6 +46,140 @@ cdef class RailNode: return f'RailNode({self.position}, {len(self.edges)})' +class GlobalObsForRailEnv(ObservationBuilder): + """ + Gives a global observation of the entire rail environment. + The observation is composed of the following elements: + + - transition map array with dimensions (env.height, env.width, 16),\ + assuming 16 bits encoding of transitions. + + - obs_agents_state: A 3D array (map_height, map_width, 5) with + - first channel containing the agents position and direction + - second channel containing the other agents positions and direction + - third channel containing agent/other agent malfunctions + - fourth channel containing agent/other agent fractional speeds + - fifth channel containing number of other agents ready to depart + + - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ + target and the positions of the other agents targets (flag only, no counter!). + """ + def __init__(self): + super(GlobalObsForRailEnv, self).__init__() + self.size = 0 + self._custom_rail_obs = None + def reset(self): + if self._custom_rail_obs is None: + self._custom_rail_obs = np.zeros((1, self.env.height + 2*self.size, self.env.width + 2*self.size, 16)) + + self._custom_rail_obs[0, self.size:-self.size, self.size:-self.size] = np.array([[[[1 if digit == '1' else 0 + for digit in + f'{self.env.rail.get_full_transitions(i, j):016b}'] + for j in range(self.env.width)] + for i in range(self.env.height)]], + dtype=np.float32) + + def get_many(self, list trash): + cdef int agent_count = len(self.env.agents) + cdef cnp.ndarray obs_agents_state = np.zeros((agent_count, + self.env.height, + self.env.width, + 5), dtype=np.float32) + cdef int i, agent_id + cdef tuple pos, agent_virtual_position + for agent_id, agent in enumerate(self.env.agents): + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_virtual_position = agent.target + else: + continue + + obs_agents_state[agent_id, :, :, 0:4] = -1 + + obs_agents_state[(agent_id,) + agent_virtual_position + (0,)] = agent.direction + + for i, other_agent in enumerate(self.env.agents): + + # ignore other agents not in the grid any more + if other_agent.status == RailAgentStatus.DONE_REMOVED: + continue + + # second to fourth channel only if in the grid + if other_agent.position is not None: + pos = (agent_id,) + other_agent.position + # second channel only for other agents + if i != agent_id: + obs_agents_state[pos + (1,)] = other_agent.direction + obs_agents_state[pos + (2,)] = other_agent.malfunction_data['malfunction'] + obs_agents_state[pos + (3,)] = other_agent.speed_data['speed'] + # fifth channel: all ready to depart on this position + if other_agent.status == RailAgentStatus.READY_TO_DEPART: + obs_agents_state[(agent_id,) + other_agent.initial_position + (4,)] += 1 + return {i: arr + for i, arr in + enumerate(np.concatenate([np.repeat(self.rail_obs, agent_count, 0), obs_agents_state], -1))} + + +class LocalObsForRailEnv(GlobalObsForRailEnv): + def __init__(self, size=7): + super(LocalObsForRailEnv, self).__init__() + self.size = size + def get_many(self, list trash): + cdef int agent_count = len(self.env.agents) + obs_agents_state = np.zeros((agent_count, + self.size * 2 + 1, + self.size * 2 + 1, + 21), dtype=np.float32) + cdef int i, agent_id + cdef tuple agent_virtual_position + for agent_id, agent in enumerate(self.env.agents): + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_virtual_position = agent.target + else: + continue + x0, y0, x1, y1 = (agent_virtual_position[0], + agent_virtual_position[1], + agent_virtual_position[0] + 2*self.size + 1, + agent_virtual_position[1] + 2*self.size + 1) + obs_agents_state[agent_id, :, :, 5:] = self._custom_rail_obs[0, x0:x1, y0:y1] + + obs_agents_state[agent_id, :, :, 0:4] = -1 + + obs_agents_state[agent_id, :, :, 0] = agent.direction + + for i, other_agent in enumerate(self.env.agents): + + # ignore other agents not in the grid any more + if other_agent.status == RailAgentStatus.DONE_REMOVED: + continue + + # second to fourth channel only if in the grid + if other_agent.position is not None: + pos = (agent_id,) + other_agent.position + # second channel only for other agents + if i != agent_id: + obs_agents_state[agent_id, :, :, 1] = other_agent.direction + obs_agents_state[agent_id, :, :, 2] = other_agent.malfunction_data['malfunction'] + obs_agents_state[agent_id, :, :, 3] = other_agent.speed_data['speed'] + # fifth channel: all ready to depart on this position + if other_agent.status == RailAgentStatus.READY_TO_DEPART: + init_pos = other_agent.initial_position + dist0 = agent_virtual_position[0] - init_pos[0] + dist1 = agent_virtual_position[1] - init_pos[1] + if abs(dist0) < self.size and abs(dist1) < self.size: + obs_agents_state[agent_id, dist0 + self.size, dist1 + self.size, 4] += 1 + return {i: arr + for i, arr in + enumerate(obs_agents_state)} + + class TreeObservation(ObservationBuilder): def __init__(self, max_depth): super().__init__() diff --git a/src/railway_utils.py b/src/railway_utils.py index 6343e7a..74df209 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -88,12 +88,12 @@ def __call__(self, rail, _, hints, *args, **kwargs): # Helper function to load in precomputed railway networks -def load_precomputed_railways(project_root, start_index, big=True): +def load_precomputed_railways(project_root, start_index, big=False): prefix = os.path.join(project_root, 'railroads') if big: - suffix = f'_45x90x90.pkl' + suffix = f'_110.pkl' else: - suffix = f'_3x30x30.pkl' + suffix = f'_50.pkl' sched = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) rail = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) #if big: diff --git a/src/train.py b/src/train.py index 727fd31..dbc105d 100644 --- a/src/train.py +++ b/src/train.py @@ -17,13 +17,13 @@ from .rail_env import RailEnv from .agent import Agent as DQN_Agent, device, BATCH_SIZE from .normalize_output_data import wrap - from .observation_utils import normalize_observation, TreeObservation + from .observation_utils import normalize_observation, TreeObservation, GlobalObsForRailEnv, LocalObsForRailEnv from .railway_utils import load_precomputed_railways, create_random_railways except: from rail_env import RailEnv from agent import Agent as DQN_Agent, device, BATCH_SIZE from normalize_output_data import wrap - from observation_utils import normalize_observation, TreeObservation + from observation_utils import normalize_observation, TreeObservation, GlobalObsForRailEnv, LocalObsForRailEnv from railway_utils import load_precomputed_railways, create_random_railways project_root = Path(__file__).resolve().parent.parent @@ -41,7 +41,7 @@ # Environment parameters parser.add_argument("--tree-depth", type=int, default=3, help="Depth of the observation tree") -parser.add_argument("--model-depth", type=int, default=1, help="Depth of the observation tree") +parser.add_argument("--model-depth", type=int, default=3, help="Depth of the observation tree") parser.add_argument("--hidden-factor", type=int, default=16, help="Depth of the observation tree") parser.add_argument("--kernel-size", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--squeeze-heads", type=int, default=4, help="Depth of the observation tree") @@ -54,11 +54,15 @@ parser.add_argument("--epsilon-decay", type=float, default=0, help="Decay factor for epsilon-greedy exploration") parser.add_argument("--step-reward", type=float, default=-1e-2, help="Depth of the observation tree") parser.add_argument("--collision-reward", type=float, default=-2, help="Depth of the observation tree") -parser.add_argument("--global-environment", type=boolean, default=False, help="Depth of the observation tree") +parser.add_argument("--global-environment", type=boolean, default=True, help="Depth of the observation tree") +parser.add_argument("--local-environment", type=boolean, default=True, help="Depth of the observation tree") parser.add_argument("--threads", type=int, default=1, help="Depth of the observation tree") flags = parser.parse_args() +if flags.local_environment: + flags.global_environment = True + # Seeded RNG so we can replicate our results # Create a tensorboard SummaryWriter @@ -91,7 +95,9 @@ schedule_generator=schedule_generator, malfunction_generator_and_process_data=malfunction_from_params( MalfunctionParameters(1 / 500, 20, 50)), - obs_builder_object=(GlobalObsForRailEnv() + obs_builder_object=((LocalObsForRailEnv(4) + if flags.local_environment + else GlobalObsForRailEnv) if flags.global_environment else TreeObservation(max_depth=flags.tree_depth)), random_seed=i) @@ -112,19 +118,27 @@ # Helper function to detect collisions ACTIONS = {0: 'B', 1: 'L', 2: 'F', 3: 'R', 4: 'S'} - -def is_collision(a, i): - if obs[i][a] is None: return False - is_junction = not isinstance(obs[i][a].childs['L'], float) or not isinstance(obs[i][a].childs['R'], float) - - if not is_junction or environments[i].agents[a].speed_data['position_fraction'] > 0: - action = ACTIONS[environments[i].agents[a].speed_data['transition_action_on_cellexit']] if is_junction else 'F' - return obs[i][a].childs[action] != negative_infinity and obs[i][a].childs[action] != positive_infinity \ - and obs[i][a].childs[action].num_agents_opposite_direction > 0 \ - and obs[i][a].childs[action].dist_other_agent_encountered <= 1 \ - and obs[i][a].childs[action].dist_other_agent_encountered < obs[i][a].childs[action].dist_unusable_switch - else: - return False +if flags.global_environment: + def is_collision(a, i): + own_agent = environments[i].agents[a] + return any(own_agent.position == agent.position + for agent_id, agent in enumerate(environments[i].agents) + if agent_id != a) +else: + def is_collision(a, i): + if obs[i][a] is None: return False + is_junction = not isinstance(obs[i][a].childs['L'], float) or not isinstance(obs[i][a].childs['R'], float) + + if not is_junction or environments[i].agents[a].speed_data['position_fraction'] > 0: + action = ACTIONS[ + environments[i].agents[a].speed_data['transition_action_on_cellexit']] if is_junction else 'F' + return obs[i][a].childs[action] != negative_infinity and obs[i][a].childs[action] != positive_infinity \ + and obs[i][a].childs[action].num_agents_opposite_direction > 0 \ + and obs[i][a].childs[action].dist_other_agent_encountered <= 1 \ + and obs[i][a].childs[action].dist_other_agent_encountered < obs[i][a].childs[ + action].dist_unusable_switch + else: + return False def get_means(x, y, c, s): @@ -159,8 +173,12 @@ def normalize(observation, target_tensor): score, steps_taken, collision = 0, 0, False agent_count = len(obs[0]) - agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) - normalize(obs, agent_obs) + if flags.global_environment: + agent_obs = torch.as_tensor([list(o.values()) for o in obs]).float().to(device) + else: + agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) + normalize(obs, agent_obs) + agent_obs_buffer = agent_obs.clone() agent_action_buffer = [[2] * agent_count for _ in range(BATCH_SIZE)] @@ -169,7 +187,11 @@ def normalize(observation, target_tensor): for step in range(max_steps): update_values = [[False] * agent_count for _ in range(BATCH_SIZE)] action_dict = [{} for _ in range(BATCH_SIZE)] - input_tensor = torch.cat([agent_obs_buffer.flatten(1, 2), agent_obs.flatten(1, 2)], 1) + if flags.global_environment: + input_tensor = torch.cat([agent_obs_buffer, agent_obs], -1) + input_tensor.transpose_(1, -1) + else: + input_tensor = torch.cat([agent_obs_buffer.flatten(1, 2), agent_obs.flatten(1, 2)], 1) if any(any(inf['action_required']) for inf in info): ret_action = agent.multi_act(input_tensor, eps=eps) else: @@ -210,7 +232,10 @@ def normalize(observation, target_tensor): if all_done: break - normalize(obs, agent_obs) + if flags.global_environment: + agent_obs = torch.as_tensor([list(o.values()) for o in obs]).float().to(device) + else: + normalize(obs, agent_obs) # Render # if flags.render_interval and episode % flags.render_interval == 0: @@ -229,7 +254,7 @@ def normalize(observation, target_tensor): current_taken, mean_taken = get_means(current_steps, mean_steps, step, episode) print(f'\rBatch {episode:>4} - Episode {BATCH_SIZE * episode:>6} - Agents: {agent_count:>3}' - f' | Score: {current_score:.4f}, {mean_score:.4f}' + f' | Score: {current_score:.4f}, {mean_score:.4f}' f' | Agent-Steps: {current_steps:6.1f}, {mean_steps:6.1f}' f' | Steps Taken: {current_taken:6.1f}, {mean_taken:6.1f}' f' | Collisions: {100 * current_collisions:5.2f}%, {100 * mean_collisions:5.2f}%' From 59e385e7f3a5adf7713c521d1141d5ccc634b124 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 17 Jul 2020 23:25:15 +0200 Subject: [PATCH 37/75] feat: re-enable cuda (after reducing input size) --- src/agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agent.py b/src/agent.py index 62221ff..c40675b 100644 --- a/src/agent.py +++ b/src/agent.py @@ -13,14 +13,14 @@ import os BUFFER_SIZE = 500_000 -BATCH_SIZE = 32 +BATCH_SIZE = 256 GAMMA = 0.998 TAU = 1e-3 CLIP_FACTOR = 0.2 LR = 4e-5 UPDATE_EVERY = 1 DOUBLE_DQN = False -CUDA = False +CUDA = True device = torch.device("cuda:0" if CUDA and torch.cuda.is_available() else "cpu") From 36c6e4a596ecf9141fc960524ebf6d812f644d58 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 17 Jul 2020 23:27:50 +0200 Subject: [PATCH 38/75] feat: add readme to keep folder structure --- checkpoints/dqn/dqn0/README.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 checkpoints/dqn/dqn0/README.md diff --git a/checkpoints/dqn/dqn0/README.md b/checkpoints/dqn/dqn0/README.md new file mode 100644 index 0000000..7792019 --- /dev/null +++ b/checkpoints/dqn/dqn0/README.md @@ -0,0 +1 @@ +DQN checkpoints will be saved here From af2ec5c7b33a11571ad102de54791b0874d09d47 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 17 Jul 2020 23:30:08 +0200 Subject: [PATCH 39/75] fix: add garbage to gitignore --- .gitignore | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/.gitignore b/.gitignore index 2a88583..6f4d625 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,24 @@ venv/ .DS_Store *.pyc +/src/a.out +/index.html +/checkpoints/dqn/dqn0/loss.txt +/.idea/** +**model_checkpoint* +/.idea/modules.xml +/src/obs.so +/src/observation_utils.c +/src/observation_utils.h +/src/observation_utils.so +/.idea/other.xml +/.idea/inspectionProfiles/profiles_settings.xml +/.idea/inspectionProfiles/Project_Default.xml +/r/rail_networks_35x40x40.pkl +/r/rail_networks_sum.pkl +/r/schedules_35x40x40.pkl +/r/schedules_sum.pkl +/.idea/vcs.xml +*.c +*.so +*.out \ No newline at end of file From 9c4e4b358e801ed8baeb83c8791e9cf911dbcbcc Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 18 Jul 2020 17:54:59 +0200 Subject: [PATCH 40/75] style: major cleanup --- README.md | 77 ++++++++--- install.sh | 16 +++ requirements.txt | 15 +-- src/agent.py | 3 - src/generate_railways.py | 6 +- src/rail_generators.pyx | 285 --------------------------------------- src/replay_memory.py | 66 --------- src/tree_observation.py | 0 8 files changed, 81 insertions(+), 387 deletions(-) create mode 100755 install.sh delete mode 100644 src/rail_generators.pyx delete mode 100644 src/replay_memory.py delete mode 100644 src/tree_observation.py diff --git a/README.md b/README.md index 8694333..a494a6d 100644 --- a/README.md +++ b/README.md @@ -1,33 +1,70 @@ # flatland-training -This repo contains an optimized version of flatland-rl's `flatland.envs.observations.TreeObsForRailEnv`. Tree-based observations allow RL models to learn much more quickly than the global observations do, but flatland's built-in TreeObsForRailEnv is kind of slow, so I wrote a faster version! This repo also contains an optimized version of [https://gitlab.aicrowd.com/flatland/baselines/blob/master/utils/observation_utils.py](https://gitlab.aicrowd.com/flatland/baselines/blob/master/utils/observation_utils.py), which flattens and normalizes the tree observations into 1D numpy arrays that can be passed to a feed-forward network. +PyTorch solution for [flatland-2020](https://www.aicrowd.com/challenges/neurips-2020-flatland-challenge/) +## Implementation -# Setup -## Create venv -`python3.7 -m venv venv`\ -`source venv/bin/activate` +This repository contains three major modules. -Verify python version is correct with: `python -V`\ -Should return `Python 3.7.something` +### Getting Started -## Install Requirements -`pip install -r requirements.txt` +#### Setup +Before following along here, please note that there is an `install.sh` script which executes all the commands here.\ +First, create virtual environment using `python3.7 -m venv venv && source venv/bin/activate`.\ +Then install the requirements with `python3 -m pip install -r requirements.txt`.\ +It is recommended to verify that the installation was successful by first checking the python version and then attempting an import of all required non-standard packages. +```bash +$ python3 --version +Python 3.7.6 +$ python3 -c "import torch, torch_optimizer, numpy, cython, flatland, gym, tqdm; print('Successfully imported packages')" +Successfully imported packages +``` +Lastly perhaps the most crucial step. It requires `gcc-7`, as no other version works. On Debian/Ubuntu, it can be installed using the apt package manager by running `apt install gcc-7`.\ +Once that's done, the python code can be compiled using cython. Compilation is done by first moving into the source folder and then executing cythonize.sh via `cd src && bash cythonize.sh`. -# Generate Railways -This script will precompute a bunch of railway maps to make training faster.\ +#### Generate Environments -`python src/generate_railways.py` +For better training performance, one can optionally generate the environments used to train the network _before_ training it. This way training is much faster as training data doesn't have to be regenerated repeatedly but instead gets loaded once on startup.\ +To start the generation of environments and their respective railways, use `python3 src/generate_railways.py --width 50`. The command will generate 50x50 grids of cities, rails and trains. -This will run for quite a long time, go get some tea...\ -But also it's fine to stop it after it completes at least one round if you just want to test things out and make sure they run.\ -If you don't care about the speedup, you can run `python src/train.py --load-railways=False` to generate railways on the fly during training instead. +#### Run Training +Finally, it's time to train the model. You can do so by running `python3 src/train.py`, which will train a basic cnn using the "local observation" method. ![https://flatland.aicrowd.com/getting-started/env.html](https://i.imgur.com/oo8EIYv.png)\ +Not only global and tree observation, but also many model parameters are implemented as well. To find out more about them, add the `--help` flag. -# Run Training -`python src/train.py` +### Structure -This will begin training one or more agents in the flatland environment.\ -This file has a lot of parameters that can be set to do different things.\ -To see all the options, use the `--help` command line argument. +Currently the code is structured in huge monolithic files. + +| Name | Description | +|----|----| +| agent.py | Reward and trainings algorithm, as well as some hyper parameters (such as batch size and learning rate)| +| generate_railways.py | Script to pre-compute and generate railways from the command line. See [#Generate Environments](#Generate-Environments)| +| model.py | PyTorch definition of tree-observation and local-observation models| +| observation_utils.pyx | Agent observation utilities called by environment to create training observation | +| rail_env.pyx | Cython-port of flatland-rl RailEnv | +| railway_utils.py | Utility script to handle creation of and iterators over railways | +| train.py | Core training loop | + +### Future Work + +The current implementation has many holes. One of them is the very poor performance received when controlling many (>10) agents at once.\ +We tackle this issue from multiple sites at once. If you would like to participate in this team, open an issue, pull request or join us on [discord](https://discord.gg/mP72wbE). +Our current approaches are listed below: + +* **Observation, Model**: + * Tree observation, graph neural networks + * Tree observation, fully-connected networks + * Tree observation, transformer + * Local observation, cnn + * Global observation, cnn +* **Teaching algorithm**: + * PPO + * (Double-) DQN +* **Misc. Freebies**: + * Epsilon-Greedy + * Multiprocessing + * Inter-agent communication + +If you are working on one of these tasks or would like to do so, please open an issue or pull request to let others know about it. Once seen, it will be added to the main repository. \ No newline at end of file diff --git a/install.sh b/install.sh new file mode 100755 index 0000000..7b752d2 --- /dev/null +++ b/install.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +function check_python { +python3 -c "import torch, torch_optimizer, numpy, cython, flatland, gym, tqdm; print('Successfully imported packages')" 2>/dev/null\ +|| (python3.7 -m venv venv && source venv/bin/activate && python3 -m pip install -r requirements.txt) || check_python +} + +check_python + +echo "Checking for GCC-7" + +gcc-7 --version > /dev/null || sudo apt install gcc-7 + +echo "Compiling source" +cd src && source cythonize.sh > /dev/null 2>/dev/null +cd .. \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 2fc4cc6..8c292f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,7 @@ -argparse==1.4.0 -flatland-rl==2.2.1 -numpy==1.18.5 -torch==1.5.0 -opencv-python==4.2.0.34 -Pillow==7.1.2 -tqdm==4.46.1 -tensorboardX==2.0 +torch +torch-optimizer +numpy +cython +flatland-rl +gym +tqdm \ No newline at end of file diff --git a/src/agent.py b/src/agent.py index c40675b..f720358 100644 --- a/src/agent.py +++ b/src/agent.py @@ -6,10 +6,8 @@ try: from .model import QNetwork, ConvNetwork - from .replay_memory import ReplayBuffer except: from model import QNetwork, ConvNetwork - from replay_memory import ReplayBuffer import os BUFFER_SIZE = 500_000 @@ -61,7 +59,6 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s self.optimizer = Optimizer(self.policy.parameters(), lr=LR, weight_decay=1e-2) # Replay memory - self.memory = ReplayBuffer(BATCH_SIZE) self.stack = [[] for _ in range(4)] self.t_step = 0 diff --git a/src/generate_railways.py b/src/generate_railways.py index 5ad11a3..c65e0f6 100755 --- a/src/generate_railways.py +++ b/src/generate_railways.py @@ -41,10 +41,6 @@ def do(schedules: list, rail_networks: list): return -manager = multiprocessing.Manager() -shared_schedules = manager.list(schedules) -shared_rail_networks = manager.list(rail_networks) -# Generate 10000 random railways in 100 batches of 100 for _ in tqdm(range(500), ncols=150, leave=False): do(schedules, rail_networks) with open(project_root / f'railroads/rail_networks_{width}.pkl', 'wb') as file: @@ -52,5 +48,5 @@ def do(schedules: list, rail_networks: list): with open(project_root / f'railroads/schedules_{width}.pkl', 'wb') as file: pickle.dump(rail_networks, file, protocol=4) -print(f"Saved {len(shared_rail_networks)} railways") +print(f"Saved {len(rail_networks)} railways") print("Done") diff --git a/src/rail_generators.pyx b/src/rail_generators.pyx deleted file mode 100644 index 5f6a98b..0000000 --- a/src/rail_generators.pyx +++ /dev/null @@ -1,285 +0,0 @@ -"""Rail generators (infrastructure manager, "Infrastrukturbetreiber").""" -import warnings -from typing import Callable, Tuple, Optional, Dict - -cimport numpy as cnp -import numpy as np -from flatland.core.grid.grid4_utils import direction_to_point -from flatland.core.grid.grid_utils import Vec2dOperations -from flatland.core.grid.rail_env_grid import RailEnvTransitions -from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \ - fix_inner_nodes, align_cell_to_city -from numpy.random.mtrand import RandomState - -RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] -RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] -cnp.import_array() -# CONSTANTS -cdef bint grid_mode = False -cdef int max_rails_between_cities = 3 -cdef int max_rails_in_city = 4 - -cdef int NORTH = 0 -cdef int EAST = 1 -cdef int SOUTH = 2 -cdef int WEST = 3 - -def generator(int width, int num_agents): - cdef int city_padding = 2 - cdef int max_num_cities = max(2, width ** 2 // 300) - - rail_trans = RailEnvTransitions() - grid_map = GridTransitionMap(width=width, height=width, transitions=rail_trans) - # We compute the city radius by the given max number of rails it can contain. - # The radius is equal to the number of tracks divided by 2 - # We add 2 cells to avoid that track lenght is to short - # We use ceil if we get uneven numbers of city radius. This is to guarantee that all rails fit within the city. - cdef int city_radius = ((max_rails_in_city + 1) // 2) + city_padding - cdef cnp.ndarray vector_field = np.zeros(shape=(width, width)) - 1. - - - # Calculate the max number of cities allowed - # and reduce the number of cities to build to avoid problems - cdef int max_feasible_cities = min(max_num_cities, ((width - 2) // (2 * (city_radius + 1))) ** 2) - - cdef bint too_close - cdef int col, tries, row - cdef tuple city_pos - cdef list city_positions = [] - cdef int min_distance = (2 * (city_radius + 1) + 1) - cdef int city_idx - - for city_idx in range(max_feasible_cities): - too_close = True - tries = 0 - - while too_close: - row = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1)) - col = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1)) - too_close = False - # Check distance to cities - for city_pos in city_positions: - if np.abs(row - city_pos[0]) < min_distance and np.abs(col - city_pos[1]) < min_distance: - too_close = True - break - - if not too_close: - city_positions.append((row, col)) - - tries += 1 - if tries > 200: - warnings.warn("Could not set all required cities!") - break - - cdef list inner_connection_points = [] - cdef list outer_connection_points = [] - cdef list city_orientations = [] - cdef list city_cells = [] - cdef list neighb_dist, connection_sides_idx, connection_points_coordinates_outer - cdef list connection_points_coordinates_inner, _city_cells - cdef int current_closest_direction, idx, nr_of_connection_points - cdef int number_of_out_rails, start_idx, direction, connection_idx - cdef tuple neighbour_city, cell - cdef tuple tmp_coordinates = tuple() - cdef out_tmp_coordinates = tuple() - cdef cnp.ndarray connections_per_direction, connection_slots, x_range, y_range, x_values, y_values, inner_point_offset - for city_pos in city_positions: - - # Chose the directions where close cities are situated - neighb_dist = [] - for neighbour_city in city_positions: - neighb_dist.append(Vec2dOperations.get_manhattan_distance(city_pos, neighbour_city)) - closest_neighb_idx = np.argsort(neighb_dist) - - # Store the directions to these neighbours and orient city to face closest neighbour - connection_sides_idx = [] - idx = 1 - current_closest_direction = direction_to_point(city_pos, city_positions[closest_neighb_idx[idx]]) - connection_sides_idx.append(current_closest_direction) - connection_sides_idx.append((current_closest_direction + 2) % 4) - city_orientations.append(current_closest_direction) - x_range = np.arange(city_pos[0] - city_radius, city_pos[0] + city_radius + 1) - y_range = np.arange(city_pos[1] - city_radius, city_pos[1] + city_radius + 1) - x_values = np.repeat(x_range, len(y_range)) - y_values = np.tile(y_range, len(x_range)) - _city_cells = list(zip(x_values, y_values)) - for cell in _city_cells: - vector_field[cell] = align_cell_to_city(city_pos, city_orientations[-1], cell) - city_cells.extend(_city_cells) - # set the number of tracks within a city, at least 2 tracks per city - connections_per_direction = np.zeros(4, dtype=int) - nr_of_connection_points = np.random.randint(2, max_rails_in_city + 1) - for idx in connection_sides_idx: - connections_per_direction[idx] = nr_of_connection_points - connection_points_coordinates_inner = [[] for _ in range(4)] - connection_points_coordinates_outer = [[] for _ in range(4)] - number_of_out_rails = np.random.randint(1, min(max_rails_in_city, nr_of_connection_points) + 1) - start_idx = int((nr_of_connection_points - number_of_out_rails) / 2) - for direction in range(4): - connection_slots = np.arange(nr_of_connection_points) - start_idx - # Offset the rails away from the center of the city - offset_distances = np.arange(nr_of_connection_points) - int(nr_of_connection_points / 2) - # The clipping helps ofsetting one side more than the other to avoid switches at same locations - # The magic number plus one is added such that all points have at least one offset - inner_point_offset = np.abs(offset_distances) + np.clip(offset_distances, 0, 1) + 1 - for connection_idx in range(connections_per_direction[direction]): - if direction == 0: - tmp_coordinates = ( - city_pos[0] - city_radius + inner_point_offset[connection_idx], - city_pos[1] + connection_slots[connection_idx]) - out_tmp_coordinates = ( - city_pos[0] - city_radius, city_pos[1] + connection_slots[connection_idx]) - if direction == 1: - tmp_coordinates = ( - city_pos[0] + connection_slots[connection_idx], - city_pos[1] + city_radius - inner_point_offset[connection_idx]) - out_tmp_coordinates = ( - city_pos[0] + connection_slots[connection_idx], city_pos[1] + city_radius) - if direction == 2: - tmp_coordinates = ( - city_pos[0] + city_radius - inner_point_offset[connection_idx], - city_pos[1] + connection_slots[connection_idx]) - out_tmp_coordinates = ( - city_pos[0] + city_radius, city_pos[1] + connection_slots[connection_idx]) - if direction == 3: - tmp_coordinates = ( - city_pos[0] + connection_slots[connection_idx], - city_pos[1] - city_radius + inner_point_offset[connection_idx]) - out_tmp_coordinates = ( - city_pos[0] + connection_slots[connection_idx], city_pos[1] - city_radius) - connection_points_coordinates_inner[direction].append(tmp_coordinates) - if connection_idx in range(start_idx, start_idx + number_of_out_rails): - connection_points_coordinates_outer[direction].append(out_tmp_coordinates) - - inner_connection_points.append(connection_points_coordinates_inner) - outer_connection_points.append(connection_points_coordinates_outer) - - cdef list inter_city_lines = [] - cdef list city_distances, closest_neighbours - cdef int current_city_idx, direction_to_neighbour, out_direction, neighbour_idx - - for current_city_idx in np.arange(len(city_positions)): - city_distances = [] - closest_neighbours = [None for _ in range(4)] - - # compute distance to all other cities - for city_idx in range(len(city_positions)): - city_distances.append( - Vec2dOperations.get_manhattan_distance(city_positions[current_city_idx], city_positions[city_idx])) - sorted_neighbours = np.argsort(city_distances) - - for neighbour in sorted_neighbours[1:]: # do not include city itself - direction_to_neighbour = direction_to_point(city_positions[current_city_idx], city_positions[neighbour]) - if closest_neighbours[direction_to_neighbour] is None: - closest_neighbours[direction_to_neighbour] = neighbour - - # early return once all 4 directions have a closest neighbour - if None not in closest_neighbours: - break - for out_direction in range(4): - if closest_neighbours[out_direction] is not None: - neighbour_idx = closest_neighbours[out_direction] - elif closest_neighbours[(out_direction - 1) % 4] is not None: - neighbour_idx = closest_neighbours[(out_direction - 1) % 4] # counter-clockwise - elif closest_neighbours[(out_direction + 1) % 4] is not None: - neighbour_idx = closest_neighbours[(out_direction + 1) % 4] # clockwise - elif closest_neighbours[(out_direction + 2) % 4] is not None: - neighbour_idx = closest_neighbours[(out_direction + 2) % 4] - - for city_out_connection_point in outer_connection_points[current_city_idx][out_direction]: - - min_connection_dist = np.inf - neighbour_connection_point = None - for direction in range(4): - current_points = outer_connection_points[neighbour_idx][direction] - for tmp_in_connection_point in current_points: - tmp_dist = Vec2dOperations.get_manhattan_distance(city_out_connection_point, - tmp_in_connection_point) - if tmp_dist < min_connection_dist: - min_connection_dist = tmp_dist - neighbour_connection_point = tmp_in_connection_point - - new_line = connect_rail_in_grid_map(grid_map, city_out_connection_point, neighbour_connection_point, - rail_trans, flip_start_node_trans=False, - flip_end_node_trans=False, respect_transition_validity=False, - avoid_rail=True, - forbidden_cells=city_cells) - inter_city_lines.extend(new_line) - - # Build inner cities - cdef int i, current_city, opposite_boarder - cdef int boarder = 0 - cdef int track_id, track_nbr - cdef list free_rails = [[] for _ in range(len(city_positions))] - for current_city in range(len(city_positions)): - - # This part only works if we have keep same number of connection points for both directions - # Also only works with two connection direction at each city - for i in range(4): - if len(inner_connection_points[current_city][i]) > 0: - boarder = i - break - - opposite_boarder = (boarder + 2) % 4 - nr_of_connection_points = len(inner_connection_points[current_city][boarder]) - number_of_out_rails = len(outer_connection_points[current_city][boarder]) - start_idx = (nr_of_connection_points - number_of_out_rails) // 2 - # Connect parallel tracks - for track_id in range(nr_of_connection_points): - source = inner_connection_points[current_city][boarder][track_id] - target = inner_connection_points[current_city][opposite_boarder][track_id] - current_track = connect_straight_line_in_grid_map(grid_map, source, target, rail_trans) - free_rails[current_city].append(current_track) - - for track_id in range(nr_of_connection_points): - source = inner_connection_points[current_city][boarder][track_id] - target = inner_connection_points[current_city][opposite_boarder][track_id] - - # Connect parallel tracks with each other - fix_inner_nodes( - grid_map, source, rail_trans) - fix_inner_nodes( - grid_map, target, rail_trans) - - # Connect outer tracks to inner tracks - if start_idx <= track_id < start_idx + number_of_out_rails: - source_outer = outer_connection_points[current_city][boarder][track_id - start_idx] - target_outer = outer_connection_points[current_city][opposite_boarder][track_id - start_idx] - connect_straight_line_in_grid_map(grid_map, source, source_outer, rail_trans) - connect_straight_line_in_grid_map(grid_map, target, target_outer, rail_trans) - - # Populate cities - cdef int num_cities = len(city_positions) - cdef list train_stations = [[] for _ in range(num_cities)] - for current_city in range(len(city_positions)): - for track_nbr in range(len(free_rails[current_city])): - possible_location = free_rails[current_city][track_nbr][ - int(len(free_rails[current_city][track_nbr]) / 2)] - train_stations[current_city].append((possible_location, track_nbr)) - - # Fix all transition elements - - cdef cnp.ndarray rails_to_fix = np.zeros(3 * grid_map.height * grid_map.width * 2, dtype='int') - cdef int rails_to_fix_cnt = 0 - cdef list cells_to_fix = city_cells + inter_city_lines - cdef bint cell_valid - for cell in cells_to_fix: - cell_valid = grid_map.cell_neighbours_valid(cell, True) - - if not cell_valid: - rails_to_fix[3 * rails_to_fix_cnt] = cell[0] - rails_to_fix[3 * rails_to_fix_cnt + 1] = cell[1] - rails_to_fix[3 * rails_to_fix_cnt + 2] = vector_field[cell] - - rails_to_fix_cnt += 1 - # Fix all other cells - for idx in range(rails_to_fix_cnt): - grid_map.fix_transitions((rails_to_fix[3 * idx], rails_to_fix[3 * idx + 1]), rails_to_fix[3 * idx + 2]) - - return grid_map, {'agents_hints': { - 'num_agents': num_agents, - 'city_positions': city_positions, - 'train_stations': train_stations, - 'city_orientations': city_orientations - }} diff --git a/src/replay_memory.py b/src/replay_memory.py deleted file mode 100644 index 8a5ddd8..0000000 --- a/src/replay_memory.py +++ /dev/null @@ -1,66 +0,0 @@ -import random -from collections import namedtuple, deque, Iterable - -import torch - -Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done")) - - -class Episode: - memory = [] - - def reset(self): - self.memory = [] - - def push(self, *args): - self.memory.append(tuple(args)) - - def discount_rewards(self, gamma): - running_add = 0. - for i, (state, action, reward, *rest) in list(enumerate(self.memory))[::-1]: - running_add = running_add * gamma + reward - self.memory[i] = (state, action, running_add, *rest) - - -class ReplayBuffer: - def __init__(self, buffer_size): - self.memory = deque(maxlen=buffer_size) - - def push(self, state, action, reward, next_state, done): - self.memory.append(Transition(torch.stack(state, -1).unsqueeze(0), - action, - reward, - torch.stack(next_state, -1).unsqueeze(0), - done)) - - def push_episode(self, episode): - for step in episode.memory: - self.push(*step) - - def sample(self, batch_size, device): - experiences = random.sample(self.memory, k=batch_size) - - states = self.stack([e.state for e in experiences]).float().to(device) - actions = self.stack([e.action for e in experiences]).long().to(device) - rewards = self.stack([e.reward for e in experiences]).float().to(device) - next_states = self.stack([e.next_state for e in experiences]).float().to(device) - dones = self.stack([[v for k, v in e.done.items() - if not hasattr(k, 'startswith') - or not k.startswith('_')] - for e in experiences]).float().to(device) - - return states, actions, rewards, next_states, dones - - def stack(self, states, dim=0): - if isinstance(states[0], Iterable): - if isinstance(states[0][0], list): - return torch.stack([self.stack(st, -1) for st in states], dim) - if isinstance(states[0], torch.Tensor): - return torch.stack(states, 0) - if isinstance(states[0], Iterable): - return torch.stack([self.stack(st, dim) for st in states], dim) - return torch.tensor(states) - return torch.tensor(states).view(len(states), 1) - - def __len__(self): - return len(self.memory) diff --git a/src/tree_observation.py b/src/tree_observation.py deleted file mode 100644 index e69de29..0000000 From ea376d4041c09d3973292303634103a3137ca35c Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 18 Jul 2020 19:36:37 +0200 Subject: [PATCH 41/75] fix: remove NaN's --- src/agent.py | 25 ++++----- src/cythonize.sh | 3 +- src/model.py | 111 +++++++++++++------------------------- src/observation_utils.pyx | 31 ++++++----- src/train.py | 6 +-- 5 files changed, 72 insertions(+), 104 deletions(-) diff --git a/src/agent.py b/src/agent.py index f720358..a1a143c 100644 --- a/src/agent.py +++ b/src/agent.py @@ -48,13 +48,16 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s kernel_size, squeeze_heads, debug=False).to(device) - try: - self.policy = torch.jit.script(self.policy) - self.old_policy = torch.jit.script(self.old_policy) - except: - import traceback - traceback.print_exc() + if CUDA: print("NO JIT") + else: + try: + self.policy = torch.jit.script(self.policy) + self.old_policy = torch.jit.script(self.old_policy) + except: + import traceback + traceback.print_exc() + print("NO JIT") self.old_policy.load_state_dict(self.policy.state_dict()) self.optimizer = Optimizer(self.policy.parameters(), lr=LR, weight_decay=1e-2) @@ -117,16 +120,14 @@ def step(self, state, action, agent_done, collision, step_reward=0, collision_re def learn(self, states, actions, rewards): self.policy.train() actions.unsqueeze_(1) - responsible_outputs = self.policy(states).gather(1, actions) - old_responsible_outputs = self.old_policy(states).gather(1, actions) + responsible_outputs = self.policy(states, True).gather(1, actions) + old_responsible_outputs = self.old_policy(states, True).gather(1, actions) old_responsible_outputs.detach_() ratio = responsible_outputs / (old_responsible_outputs + 1e-5) + ratio.squeeze_(1) clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) - loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean() - - # rewards = rewards - rewards.mean() + loss = -torch.min(ratio * rewards, clamped_ratio * rewards).sum(-1).mean() - # Compute loss and perform a gradient step self.old_policy.load_state_dict(self.policy.state_dict()) self.optimizer.zero_grad() loss.backward() diff --git a/src/cythonize.sh b/src/cythonize.sh index b7762f5..4aeff8f 100644 --- a/src/cythonize.sh +++ b/src/cythonize.sh @@ -10,5 +10,4 @@ function compile { } compile observation_utils -compile rail_env -compile rail_generators \ No newline at end of file +compile rail_env \ No newline at end of file diff --git a/src/model.py b/src/model.py index 073eac2..faeda7c 100644 --- a/src/model.py +++ b/src/model.py @@ -8,87 +8,48 @@ @torch.jit.script def mish(fn_input: torch.Tensor) -> torch.Tensor: return fn_input * torch.tanh(torch.nn.functional.softplus(fn_input)) + + @torch.jit.script def nothing(x): return x + class Mish(torch.nn.Module): def forward(self, fn_input: torch.Tensor) -> torch.Tensor: return mish(fn_input) -class WeightDropConv(torch.nn.Module): - """ - Wrapper around :class:`torch.nn.Linear` that adds ``weight_dropout`` named argument. - - Args: - weight_dropout (float): The probability a weight will be dropped. - """ - - def __init__(self, in_features: int, out_features: int, kernel_size: typing.Union[int, tuple] = 1, bias=True, - weight_dropout=0.1, groups=1, - padding=0, dilation=1, function=torch.nn.functional.conv1d, stride=1): - super().__init__() - self.weight_dropout = weight_dropout - if in_features % groups != 0: - print(f"[ERROR] Unable to get weight for in={in_features},groups={groups}. Make sure they are divisible.") - if out_features % groups != 0: - print(f"[ERROR] Unable to get weight for out={out_features},groups={groups}. Make sure they are divisible.") - if isinstance(kernel_size, int): - self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features // groups, kernel_size)) - else: - self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features // groups, *kernel_size)) - if bias: - self.bias = torch.nn.Parameter(torch.Tensor(out_features)) - else: - self.register_parameter('bias', None) - self.padding = padding - self.dilation = dilation - self.groups = groups - self.stride = stride - self._function = function - - def forward(self, fn_input): - drop = torch.nn.functional.dropout(self.weight, self.weight_dropout, self.training) - if drop.dtype != self.weight.dtype: - drop = drop.to(self.weight.dtype) - return self._function(fn_input, - drop, - bias=self.bias, - padding=self.padding, - dilation=self.dilation, - groups=self.groups, - stride=self.stride) - - class SeparableConvolution(torch.nn.Module): def __init__(self, in_features, out_features, kernel_size: typing.Union[int, tuple], padding: typing.Union[int, tuple] = 0, dilation: typing.Union[int, tuple] = 1, - bias=False, dim=1, stride=1): + bias=False, dim=1, stride=1, dropout=0.1): super(SeparableConvolution, self).__init__() - self.depthwise = kernel_size > 1 if isinstance(kernel_size, int) else all(k>1 for k in kernel_size) - function = getattr(torch.nn.functional, f'conv{dim}d') + self.depthwise = kernel_size > 1 if isinstance(kernel_size, int) else all(k > 1 for k in kernel_size) + conv = getattr(torch.nn, f'Conv{dim}d') norm = getattr(torch.nn, f'BatchNorm{dim}d') if isinstance(kernel_size, int): kernel_size = (kernel_size,) * dim if self.depthwise: - self.depthwise_conv = WeightDropConv(in_features, in_features, - kernel_size, - padding=padding, - groups=in_features, - dilation=dilation, - bias=False, - function=function, - stride=stride) + self.depthwise_conv = conv(in_features, in_features, + kernel_size, + padding=padding, + groups=in_features, + dilation=dilation, + bias=False, + stride=stride) self.mid_norm = norm(in_features) else: self.depthwise_conv = nothing self.mid_norm = nothing - self.pointwise_conv = WeightDropConv(in_features, out_features, (1,)*dim , bias=bias, function=function) + self.pointwise_conv = conv(in_features, out_features, (1,) * dim, bias=bias) self.str = (f'SeparableConvolution({in_features}, {out_features}, {kernel_size}, ' + f'dilation={dilation}, padding={padding})') + self.dropout = dropout * (in_features == out_features and (stride == 1 or all(stride) == 1)) def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + if torch.rand(1) < self.dropout: + return fn_input if self.depthwise: fn_input = self.mid_norm(self.depthwise_conv(fn_input)) return self.pointwise_conv(fn_input) @@ -103,47 +64,51 @@ def __repr__(self): class BasicBlock(torch.nn.Module): def __init__(self, in_features, out_features, stride, init_norm=False): super(BasicBlock, self).__init__() - self.init_norm = torch.nn.BatchNorm3d(in_features) if init_norm else None + self.activate = init_norm + self.init_norm = torch.nn.InstanceNorm3d(in_features) if init_norm else nothing self.init_conv = SeparableConvolution(in_features, out_features, (3, 3, 1), (1, 1, 0), stride=(stride, stride, 1), dim=3) - self.mid_norm = torch.nn.BatchNorm3d(out_features) + self.mid_norm = torch.nn.InstanceNorm3d(out_features) self.end_conv = SeparableConvolution(out_features, out_features, (3, 3, 1), (1, 1, 0), dim=3) - self.shortcut = (None + self.shortcut = (nothing if stride == 1 and in_features == out_features else SeparableConvolution(in_features, out_features, (3, 3, 1), (1, 1, 0), stride=(stride, stride, 1), dim=3)) def forward(self, fn_input: torch.Tensor) -> torch.Tensor: - out = self.init_conv(fn_input if self.init_norm is None else mish(self.init_norm(fn_input))) + out = self.init_conv(fn_input if self.activate is None else mish(self.init_norm(fn_input))) out = mish(self.mid_norm(out)) out = self.end_conv(out) - if self.shortcut is not None: - fn_input = self.shortcut(fn_input) + fn_input = self.shortcut(fn_input) out = out + fn_input return out class ConvNetwork(torch.nn.Module): def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, cat=True, - debug=True): + debug=True, embedding_size=1): super(ConvNetwork, self).__init__() _ = state_size - state_size = 2*21 - self.net = torch.nn.ModuleList([BasicBlock(state_size if not i else hidden_size, hidden_size, 2, True) - for i in range(depth)]) - self.init_norm = torch.nn.BatchNorm1d(hidden_size) + state_size = 2 * 21 + self.embedding = torch.nn.Embedding(5, embedding_size) + self.net = torch.nn.ModuleList([BasicBlock(embedding_size * state_size if not i else hidden_size, + hidden_size, + 2, + init_norm=bool(i)) + for i in range(depth)]) + self.init_norm = torch.nn.InstanceNorm1d(hidden_size) self.linear0 = torch.nn.Conv1d(hidden_size, hidden_size, 1, bias=False) - self.mid_norm = torch.nn.BatchNorm1d(hidden_size) + self.mid_norm = torch.nn.InstanceNorm1d(hidden_size) self.linear1 = torch.nn.Conv1d(hidden_size, action_size, 1) print(self) - def forward(self, fn_input: torch.Tensor) -> torch.Tensor: - out = fn_input + def forward(self, fn_input: torch.Tensor, softmax=False) -> torch.Tensor: + out = fn_input.float() for module in self.net: out = module(out) out = out.mean((2, 3)) out = self.linear1(mish(self.mid_norm(self.linear0(mish(self.init_norm(out)))))) - return out + return torch.nn.functional.softmax(out, 1) if softmax else out def init(module: torch.nn.Module): @@ -161,7 +126,7 @@ class Residual(torch.nn.Module): def __init__(self, features): super(Residual, self).__init__() self.norm = torch.nn.BatchNorm1d(features) - self.conv = WeightDropConv(features, 2 * features) + self.conv = torch.nn.Conv1d(features, 2 * features) def forward(self, fn_input: torch.Tensor) -> torch.Tensor: out, exc = self.conv(mish(self.norm(fn_input))).chunk(2, 1) @@ -177,7 +142,7 @@ def QNetwork(state_size, action_size, hidden_factor=16, depth=4, kernel_size=7, Residual(11 * hidden_factor), torch.nn.BatchNorm1d(11 * hidden_factor), Mish(), - WeightDropConv(11 * hidden_factor, action_size, 1)) + torch.nn.Conv1d(11 * hidden_factor, action_size, 1)) print(model) if debug: parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, model.parameters())) diff --git a/src/observation_utils.pyx b/src/observation_utils.pyx index abd7256..fd50289 100644 --- a/src/observation_utils.pyx +++ b/src/observation_utils.pyx @@ -1,4 +1,5 @@ from collections import defaultdict + cimport numpy as cnp import numpy as np import torch @@ -70,21 +71,23 @@ class GlobalObsForRailEnv(ObservationBuilder): self._custom_rail_obs = None def reset(self): if self._custom_rail_obs is None: - self._custom_rail_obs = np.zeros((1, self.env.height + 2*self.size, self.env.width + 2*self.size, 16)) + self._custom_rail_obs = np.zeros((1, self.env.height + 2 * self.size, self.env.width + 2 * self.size, 16)) self._custom_rail_obs[0, self.size:-self.size, self.size:-self.size] = np.array([[[[1 if digit == '1' else 0 - for digit in - f'{self.env.rail.get_full_transitions(i, j):016b}'] - for j in range(self.env.width)] - for i in range(self.env.height)]], - dtype=np.float32) + for digit in + f'{self.env.rail.get_full_transitions(i, j):016b}'] + for j in + range(self.env.width)] + for i in + range(self.env.height)]], + dtype=np.int64) def get_many(self, list trash): cdef int agent_count = len(self.env.agents) cdef cnp.ndarray obs_agents_state = np.zeros((agent_count, - self.env.height, - self.env.width, - 5), dtype=np.float32) + self.env.height, + self.env.width, + 5), dtype=np.int64) cdef int i, agent_id cdef tuple pos, agent_virtual_position for agent_id, agent in enumerate(self.env.agents): @@ -130,9 +133,9 @@ class LocalObsForRailEnv(GlobalObsForRailEnv): def get_many(self, list trash): cdef int agent_count = len(self.env.agents) obs_agents_state = np.zeros((agent_count, - self.size * 2 + 1, - self.size * 2 + 1, - 21), dtype=np.float32) + self.size * 2 + 1, + self.size * 2 + 1, + 21), dtype=np.int64) cdef int i, agent_id cdef tuple agent_virtual_position for agent_id, agent in enumerate(self.env.agents): @@ -146,8 +149,8 @@ class LocalObsForRailEnv(GlobalObsForRailEnv): continue x0, y0, x1, y1 = (agent_virtual_position[0], agent_virtual_position[1], - agent_virtual_position[0] + 2*self.size + 1, - agent_virtual_position[1] + 2*self.size + 1) + agent_virtual_position[0] + 2 * self.size + 1, + agent_virtual_position[1] + 2 * self.size + 1) obs_agents_state[agent_id, :, :, 5:] = self._custom_rail_obs[0, x0:x1, y0:y1] obs_agents_state[agent_id, :, :, 0:4] = -1 diff --git a/src/train.py b/src/train.py index dbc105d..020682f 100644 --- a/src/train.py +++ b/src/train.py @@ -103,7 +103,7 @@ random_seed=i) for i in range(BATCH_SIZE)] env = environments[0] - +torch.autograd.set_detect_anomaly(True) # After training we want to render the results so we also load a renderer # Add some variables to keep track of the progress @@ -174,7 +174,7 @@ def normalize(observation, target_tensor): score, steps_taken, collision = 0, 0, False agent_count = len(obs[0]) if flags.global_environment: - agent_obs = torch.as_tensor([list(o.values()) for o in obs]).float().to(device) + agent_obs = torch.as_tensor([list(o.values()) for o in obs]).long().to(device) else: agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) normalize(obs, agent_obs) @@ -233,7 +233,7 @@ def normalize(observation, target_tensor): break if flags.global_environment: - agent_obs = torch.as_tensor([list(o.values()) for o in obs]).float().to(device) + agent_obs = torch.as_tensor([list(o.values()) for o in obs]).long().to(device) else: normalize(obs, agent_obs) From 38303531bdba9e602cb6b92500bca505167d723b Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 19 Jul 2020 09:29:06 +0200 Subject: [PATCH 42/75] perf(agent): add JIT, remove unnecessary list comprehension --- src/agent.py | 71 ++++++++++++++++++++++------------------------------ 1 file changed, 30 insertions(+), 41 deletions(-) diff --git a/src/agent.py b/src/agent.py index a1a143c..d1ae6e1 100644 --- a/src/agent.py +++ b/src/agent.py @@ -1,13 +1,14 @@ +import math import pickle -import random +import numpy as np import torch from torch_optimizer import Yogi as Optimizer try: - from .model import QNetwork, ConvNetwork + from .model import QNetwork, ConvNetwork, init except: - from model import QNetwork, ConvNetwork + from model import QNetwork, ConvNetwork, init import os BUFFER_SIZE = 500_000 @@ -15,7 +16,7 @@ GAMMA = 0.998 TAU = 1e-3 CLIP_FACTOR = 0.2 -LR = 4e-5 +LR = 1e-4 UPDATE_EVERY = 1 DOUBLE_DQN = False CUDA = True @@ -23,11 +24,9 @@ device = torch.device("cuda:0" if CUDA and torch.cuda.is_available() else "cpu") - - class Agent: def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_size, squeeze_heads, - use_global=False): + use_global=False, softmax=True, debug=True): self.action_size = action_size # Q-Network @@ -40,24 +39,33 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s hidden_factor, model_depth, kernel_size, - squeeze_heads).to(device) + squeeze_heads, + softmax=softmax).to(device) self.old_policy = network(state_size, action_size, hidden_factor, model_depth, kernel_size, squeeze_heads, + softmax=softmax, debug=False).to(device) - if CUDA: + if debug: + print(self.policy) + + parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, self.policy.parameters())) + digits = int(math.log10(parameters)) + number_string = " kMGTPEZY"[digits // 3] + + print(f"[DEBUG/MODEL] Training with {parameters * 10 ** -(digits // 3 * 3):.1f}" + f"{number_string} parameters") + self.policy.apply(init) + try: + self.policy = torch.jit.script(self.policy) + self.old_policy = torch.jit.script(self.old_policy) + except: + import traceback + traceback.print_exc() print("NO JIT") - else: - try: - self.policy = torch.jit.script(self.policy) - self.old_policy = torch.jit.script(self.old_policy) - except: - import traceback - traceback.print_exc() - print("NO JIT") self.old_policy.load_state_dict(self.policy.state_dict()) self.optimizer = Optimizer(self.policy.parameters(), lr=LR, weight_decay=1e-2) @@ -68,34 +76,14 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s def reset(self): self.finished = False - # Decide on an action to take in the environment - - def act(self, state, eps=0.): - agent_count = len(state) - state = torch.stack(state, -1).unsqueeze(0).to(device) - self.policy.eval() - with torch.no_grad(): - action_values = self.policy(state) - - # Epsilon-greedy action selection - return [torch.argmax(action_values[:, :, i], 1).item() - if random.random() > eps - else torch.randint(self.action_size, ()).item() - for i in range(agent_count)] - - def multi_act(self, state, eps=0.): - agent_count = state.size(-1) + def multi_act(self, state): state = state.to(device) self.policy.eval() with torch.no_grad(): action_values = self.policy(state) # Epsilon-greedy action selection - return [[torch.argmax(act[:, i], 0).item() - if random.random() > eps - else torch.randint(self.action_size, ()).item() - for i in range(agent_count)] - for act in action_values.__iter__()] + return action_values.argmax(1).detach().cpu().numpy() # Record the results of the agent's action and update the model @@ -120,8 +108,9 @@ def step(self, state, action, agent_done, collision, step_reward=0, collision_re def learn(self, states, actions, rewards): self.policy.train() actions.unsqueeze_(1) - responsible_outputs = self.policy(states, True).gather(1, actions) - old_responsible_outputs = self.old_policy(states, True).gather(1, actions) + responsible_outputs = self.policy(states).gather(1, actions) + with torch.no_grad(): + old_responsible_outputs = self.old_policy(states).gather(1, actions) old_responsible_outputs.detach_() ratio = responsible_outputs / (old_responsible_outputs + 1e-5) ratio.squeeze_(1) From e83378e886881a9c4550bc2f7bf5a8c8039a1995 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 19 Jul 2020 09:29:35 +0200 Subject: [PATCH 43/75] perf(model): use instancenorm, add message box --- src/model.py | 43 ++++++++++++++++++++----------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/src/model.py b/src/model.py index faeda7c..c94245b 100644 --- a/src/model.py +++ b/src/model.py @@ -23,11 +23,11 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: class SeparableConvolution(torch.nn.Module): def __init__(self, in_features, out_features, kernel_size: typing.Union[int, tuple], padding: typing.Union[int, tuple] = 0, dilation: typing.Union[int, tuple] = 1, - bias=False, dim=1, stride=1, dropout=0.1): + bias=False, dim=1, stride=1, dropout=0.25): super(SeparableConvolution, self).__init__() self.depthwise = kernel_size > 1 if isinstance(kernel_size, int) else all(k > 1 for k in kernel_size) conv = getattr(torch.nn, f'Conv{dim}d') - norm = getattr(torch.nn, f'BatchNorm{dim}d') + norm = getattr(torch.nn, f'InstanceNorm{dim}d') if isinstance(kernel_size, int): kernel_size = (kernel_size,) * dim if self.depthwise: @@ -38,13 +38,13 @@ def __init__(self, in_features, out_features, kernel_size: typing.Union[int, tup dilation=dilation, bias=False, stride=stride) - self.mid_norm = norm(in_features) + self.mid_norm = norm(in_features, affine=True) else: self.depthwise_conv = nothing self.mid_norm = nothing self.pointwise_conv = conv(in_features, out_features, (1,) * dim, bias=bias) self.str = (f'SeparableConvolution({in_features}, {out_features}, {kernel_size}, ' - + f'dilation={dilation}, padding={padding})') + + f'dilation={dilation}, padding={padding}, stride={stride})') self.dropout = dropout * (in_features == out_features and (stride == 1 or all(stride) == 1)) def forward(self, fn_input: torch.Tensor) -> torch.Tensor: @@ -62,53 +62,55 @@ def __repr__(self): class BasicBlock(torch.nn.Module): - def __init__(self, in_features, out_features, stride, init_norm=False): + def __init__(self, in_features, out_features, stride, init_norm=False, message_box=None): super(BasicBlock, self).__init__() self.activate = init_norm - self.init_norm = torch.nn.InstanceNorm3d(in_features) if init_norm else nothing + self.init_norm = torch.nn.InstanceNorm3d(in_features, affine=True) if init_norm else nothing self.init_conv = SeparableConvolution(in_features, out_features, (3, 3, 1), (1, 1, 0), stride=(stride, stride, 1), dim=3) - self.mid_norm = torch.nn.InstanceNorm3d(out_features) + self.mid_norm = torch.nn.InstanceNorm3d(out_features, affine=True) self.end_conv = SeparableConvolution(out_features, out_features, (3, 3, 1), (1, 1, 0), dim=3) self.shortcut = (nothing if stride == 1 and in_features == out_features else SeparableConvolution(in_features, out_features, (3, 3, 1), (1, 1, 0), stride=(stride, stride, 1), dim=3)) + self.message_box = int(out_features ** 0.5) if message_box is None else message_box def forward(self, fn_input: torch.Tensor) -> torch.Tensor: out = self.init_conv(fn_input if self.activate is None else mish(self.init_norm(fn_input))) out = mish(self.mid_norm(out)) - out = self.end_conv(out) + out: torch.Tensor = self.end_conv(out) + out[:, :self.message_box] = out[:, :self.message_box].mean(-1, + keepdim=True).expand(-1, -1, -1, -1, + fn_input.size(-1)) fn_input = self.shortcut(fn_input) out = out + fn_input return out - class ConvNetwork(torch.nn.Module): def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, cat=True, - debug=True, embedding_size=1): + debug=True, softmax=False): super(ConvNetwork, self).__init__() _ = state_size state_size = 2 * 21 - self.embedding = torch.nn.Embedding(5, embedding_size) - self.net = torch.nn.ModuleList([BasicBlock(embedding_size * state_size if not i else hidden_size, + self.net = torch.nn.ModuleList([BasicBlock(state_size if not i else hidden_size, hidden_size, 2, init_norm=bool(i)) for i in range(depth)]) - self.init_norm = torch.nn.InstanceNorm1d(hidden_size) + self.init_norm = torch.nn.InstanceNorm1d(hidden_size, affine=True) self.linear0 = torch.nn.Conv1d(hidden_size, hidden_size, 1, bias=False) - self.mid_norm = torch.nn.InstanceNorm1d(hidden_size) + self.mid_norm = torch.nn.InstanceNorm1d(hidden_size, affine=True) self.linear1 = torch.nn.Conv1d(hidden_size, action_size, 1) - print(self) + self.softmax = softmax - def forward(self, fn_input: torch.Tensor, softmax=False) -> torch.Tensor: - out = fn_input.float() + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + out = fn_input for module in self.net: out = module(out) out = out.mean((2, 3)) out = self.linear1(mish(self.mid_norm(self.linear0(mish(self.init_norm(out)))))) - return torch.nn.functional.softmax(out, 1) if softmax else out + return torch.nn.functional.softmax(out, 1) if self.softmax else out def init(module: torch.nn.Module): @@ -144,12 +146,7 @@ def QNetwork(state_size, action_size, hidden_factor=16, depth=4, kernel_size=7, Mish(), torch.nn.Conv1d(11 * hidden_factor, action_size, 1)) print(model) - if debug: - parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, model.parameters())) - digits = int(math.log10(parameters)) - number_string = " kMGTPEZY"[digits // 3] - print(f"[DEBUG/MODEL] Training with {parameters * 10 ** -(digits // 3 * 3):.1f}{number_string} parameters") model.apply(init) try: model = torch.jit.script(model) From 4c35628153b440bff53eeff5a0d215a5ae5c11ba Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 19 Jul 2020 09:29:55 +0200 Subject: [PATCH 44/75] style: remove unused variables --- src/train.py | 43 ++++++++++++++----------------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/src/train.py b/src/train.py index 020682f..d3e2975 100644 --- a/src/train.py +++ b/src/train.py @@ -34,24 +34,17 @@ parser.add_argument("--train", type=boolean, default=True, help="Whether to train the model or just evaluate it") parser.add_argument("--load-model", default=False, action='store_true', help="Whether to load the model from the last checkpoint") -parser.add_argument("--load-railways", type=boolean, default=True, - help="Whether to load in pre-generated railway networks") -parser.add_argument("--report-interval", type=int, default=100, help="Iterations between reports") parser.add_argument("--render-interval", type=int, default=0, help="Iterations between renders") # Environment parameters -parser.add_argument("--tree-depth", type=int, default=3, help="Depth of the observation tree") +parser.add_argument("--tree-depth", type=int, default=2, help="Depth of the observation tree") parser.add_argument("--model-depth", type=int, default=3, help="Depth of the observation tree") -parser.add_argument("--hidden-factor", type=int, default=16, help="Depth of the observation tree") +parser.add_argument("--hidden-factor", type=int, default=48, help="Depth of the observation tree") parser.add_argument("--kernel-size", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--squeeze-heads", type=int, default=4, help="Depth of the observation tree") - -parser.add_argument("--environment-width", type=int, default=35, help="Depth of the observation tree") -parser.add_argument("--agent-factor", type=float, default=1.1, help="Depth of the observation tree") +parser.add_argument("--observation-size", type=int, default=4, help="Depth of the observation tree") # Training parameters -parser.add_argument("--num-episodes", type=int, default=10 ** 6, help="Number of episodes to train for") -parser.add_argument("--epsilon-decay", type=float, default=0, help="Decay factor for epsilon-greedy exploration") parser.add_argument("--step-reward", type=float, default=-1e-2, help="Depth of the observation tree") parser.add_argument("--collision-reward", type=float, default=-2, help="Depth of the observation tree") parser.add_argument("--global-environment", type=boolean, default=True, help="Depth of the observation tree") @@ -80,14 +73,11 @@ flags.squeeze_heads, flags.global_environment) if flags.load_model: - start, eps = agent.load(project_root / 'checkpoints', 0, 1.0) + start,_ = agent.load(project_root / 'checkpoints', 0, 1.0) else: - start, eps = 0, 1.0 + start = 0 # We need to either load in some pre-generated railways from disk, or else create a random railway generator. -if flags.load_railways: - rail_generator, schedule_generator = load_precomputed_railways(project_root, start) -else: - rail_generator, schedule_generator = create_random_railways(flags.environment_width, flags.agent_factor) +rail_generator, schedule_generator = load_precomputed_railways(project_root, start) # Create the Flatland environment environments = [RailEnv(width=40, height=40, number_of_agents=1, @@ -95,7 +85,7 @@ schedule_generator=schedule_generator, malfunction_generator_and_process_data=malfunction_from_params( MalfunctionParameters(1 / 500, 20, 50)), - obs_builder_object=((LocalObsForRailEnv(4) + obs_builder_object=((LocalObsForRailEnv(flags.observation_size) if flags.local_environment else GlobalObsForRailEnv) if flags.global_environment @@ -112,9 +102,6 @@ agent_action_buffer = [] start_time = time.time() -if not flags.train: - eps = 0.0 - # Helper function to detect collisions ACTIONS = {0: 'B', 1: 'L', 2: 'F', 3: 'R', 4: 'S'} @@ -167,14 +154,16 @@ def normalize(observation, target_tensor): POOL = multiprocessing.Pool() # Main training loop -for episode in range(start + 1, flags.num_episodes + 1): +episode = 0 +while True: + episode += 1 agent.reset() obs, info = zip(*[env.reset() for env in environments]) score, steps_taken, collision = 0, 0, False agent_count = len(obs[0]) if flags.global_environment: - agent_obs = torch.as_tensor([list(o.values()) for o in obs]).long().to(device) + agent_obs = torch.as_tensor([list(o.values()) for o in obs], dtype=torch.float, device=device) else: agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) normalize(obs, agent_obs) @@ -193,7 +182,7 @@ def normalize(observation, target_tensor): else: input_tensor = torch.cat([agent_obs_buffer.flatten(1, 2), agent_obs.flatten(1, 2)], 1) if any(any(inf['action_required']) for inf in info): - ret_action = agent.multi_act(input_tensor, eps=eps) + ret_action = agent.multi_act(input_tensor) else: ret_action = update_values for idx, act_list in enumerate(ret_action): @@ -233,7 +222,7 @@ def normalize(observation, target_tensor): break if flags.global_environment: - agent_obs = torch.as_tensor([list(o.values()) for o in obs]).long().to(device) + agent_obs = torch.as_tensor([list(o.values()) for o in obs], dtype=torch.float, device=device) else: normalize(obs, agent_obs) @@ -244,9 +233,6 @@ def normalize(observation, target_tensor): # print("Collisions detected by agent(s)", ', '.join(str(a) for a in obs if is_collision(a))) # break - # Epsilon decay - if flags.train: - eps = max(0.01, flags.epsilon_decay * eps) current_collisions, mean_collisions = get_means(current_collisions, mean_collisions, int(collision), episode) current_score, mean_score = get_means(current_score, mean_score, score / max_steps, episode) @@ -258,9 +244,8 @@ def normalize(observation, target_tensor): f' | Agent-Steps: {current_steps:6.1f}, {mean_steps:6.1f}' f' | Steps Taken: {current_taken:6.1f}, {mean_taken:6.1f}' f' | Collisions: {100 * current_collisions:5.2f}%, {100 * mean_collisions:5.2f}%' - f' | Epsilon: {eps:.2f}' f' | Episode/s: {BATCH_SIZE * episode / (time.time() - start_time):.4f}s', end='') print("") if flags.train: - agent.save(project_root / 'checkpoints', episode, eps) + agent.save(project_root / 'checkpoints', episode) From bc1cb82a92781be66a8777c2d07781ccf655a32c Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 19 Jul 2020 10:41:48 +0200 Subject: [PATCH 45/75] perf(observation): use int8 --- src/observation_utils.pyx | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/observation_utils.pyx b/src/observation_utils.pyx index fd50289..1951184 100644 --- a/src/observation_utils.pyx +++ b/src/observation_utils.pyx @@ -65,10 +65,11 @@ class GlobalObsForRailEnv(ObservationBuilder): - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ target and the positions of the other agents targets (flag only, no counter!). """ - def __init__(self): + def __init__(self, data_type=np.int8): super(GlobalObsForRailEnv, self).__init__() self.size = 0 self._custom_rail_obs = None + self.data_type = data_type def reset(self): if self._custom_rail_obs is None: self._custom_rail_obs = np.zeros((1, self.env.height + 2 * self.size, self.env.width + 2 * self.size, 16)) @@ -80,14 +81,14 @@ class GlobalObsForRailEnv(ObservationBuilder): range(self.env.width)] for i in range(self.env.height)]], - dtype=np.int64) + dtype=self.data_type) def get_many(self, list trash): cdef int agent_count = len(self.env.agents) cdef cnp.ndarray obs_agents_state = np.zeros((agent_count, self.env.height, self.env.width, - 5), dtype=np.int64) + 5), dtype=self.data_type) cdef int i, agent_id cdef tuple pos, agent_virtual_position for agent_id, agent in enumerate(self.env.agents): @@ -135,7 +136,7 @@ class LocalObsForRailEnv(GlobalObsForRailEnv): obs_agents_state = np.zeros((agent_count, self.size * 2 + 1, self.size * 2 + 1, - 21), dtype=np.int64) + 21), dtype=np.float32) cdef int i, agent_id cdef tuple agent_virtual_position for agent_id, agent in enumerate(self.env.agents): @@ -178,9 +179,7 @@ class LocalObsForRailEnv(GlobalObsForRailEnv): dist1 = agent_virtual_position[1] - init_pos[1] if abs(dist0) < self.size and abs(dist1) < self.size: obs_agents_state[agent_id, dist0 + self.size, dist1 + self.size, 4] += 1 - return {i: arr - for i, arr in - enumerate(obs_agents_state)} + return obs_agents_state class TreeObservation(ObservationBuilder): @@ -323,8 +322,8 @@ class TreeObservation(ObservationBuilder): for start, _, start_direction, distance in self.edge_positions[(*agent.position, direction)]: self.edges_with_malfunctions[(*start.position, start_direction)][agent.handle] = \ (distance, agent.malfunction_data['malfunction']) - - return super().get_many(handles) + cdef dict data = super().get_many(handles) + return tuple(tuple(dat.values()) for dat in data) # Compute the observation for a single agent def get(self, int handle): From cdacc24ea0cfc61b872488fd99012737eb06d816 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 19 Jul 2020 10:42:05 +0200 Subject: [PATCH 46/75] perf(rail-env): remove type enforcing --- src/rail_env.pyx | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/rail_env.pyx b/src/rail_env.pyx index 842e884..97221be 100644 --- a/src/rail_env.pyx +++ b/src/rail_env.pyx @@ -337,8 +337,7 @@ class RailEnv(Environment): 'status': {i: agent.status for i, agent in enumerate(self.agents)} } # Return the new observation vectors for each agent - cdef dict observation_dict = self._get_observations() - return observation_dict, info_dict + return self._get_observations(), info_dict def _fix_agent_after_malfunction(self, agent: EnvAgent): """ From e0b244fb2c44742534c36145f497507b170e6cdd Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 19 Jul 2020 10:42:22 +0200 Subject: [PATCH 47/75] perf(train): add support for np.ndarray observation --- src/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/train.py b/src/train.py index d3e2975..1523083 100644 --- a/src/train.py +++ b/src/train.py @@ -163,7 +163,7 @@ def normalize(observation, target_tensor): score, steps_taken, collision = 0, 0, False agent_count = len(obs[0]) if flags.global_environment: - agent_obs = torch.as_tensor([list(o.values()) for o in obs], dtype=torch.float, device=device) + agent_obs = torch.as_tensor(obs, dtype=torch.float, device=device) else: agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) normalize(obs, agent_obs) @@ -201,7 +201,7 @@ def normalize(observation, target_tensor): # Check for collisions and episode completion all_done = (step == (max_steps - 1)) or any(d['__all__'] for d in done) - if any(is_collision(a, i) for i, o in enumerate(obs) for a in o): + if any(is_collision(a, i) for i in range(BATCH_SIZE) for a in range(agent_count)): collision = True # done['__all__'] = True @@ -222,7 +222,7 @@ def normalize(observation, target_tensor): break if flags.global_environment: - agent_obs = torch.as_tensor([list(o.values()) for o in obs], dtype=torch.float, device=device) + agent_obs = torch.as_tensor(obs, dtype=torch.float, device=device) else: normalize(obs, agent_obs) From ae8073c11b514ee05c9fc1aba660157a54cd2131 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 19 Jul 2020 15:30:26 +0200 Subject: [PATCH 48/75] feat: add finish rate --- src/model.py | 2 +- src/railway_utils.py | 1 + src/train.py | 29 ++++++++++++++++------------- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/model.py b/src/model.py index c94245b..35293f6 100644 --- a/src/model.py +++ b/src/model.py @@ -72,7 +72,7 @@ def __init__(self, in_features, out_features, stride, init_norm=False, message_b self.end_conv = SeparableConvolution(out_features, out_features, (3, 3, 1), (1, 1, 0), dim=3) self.shortcut = (nothing if stride == 1 and in_features == out_features - else SeparableConvolution(in_features, out_features, (3, 3, 1), (1, 1, 0), + else SeparableConvolution(in_features, out_features, (1, 1, 1), (0, 0, 0), stride=(stride, stride, 1), dim=3)) self.message_box = int(out_features ** 0.5) if message_box is None else message_box diff --git a/src/railway_utils.py b/src/railway_utils.py index 74df209..182e8aa 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -13,6 +13,7 @@ class Generator: + def __init__(self, path, start_index=0): self.path = path self.index = start_index diff --git a/src/train.py b/src/train.py index 1523083..b8e08d0 100644 --- a/src/train.py +++ b/src/train.py @@ -5,7 +5,6 @@ import torch from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters -from flatland.envs.observations import GlobalObsForRailEnv from pathos import multiprocessing torch.jit.optimized_execution(True) @@ -73,7 +72,7 @@ flags.squeeze_heads, flags.global_environment) if flags.load_model: - start,_ = agent.load(project_root / 'checkpoints', 0, 1.0) + start, _ = agent.load(project_root / 'checkpoints', 0, 1.0) else: start = 0 # We need to either load in some pre-generated railways from disk, or else create a random railway generator. @@ -172,7 +171,8 @@ def normalize(observation, target_tensor): agent_action_buffer = [[2] * agent_count for _ in range(BATCH_SIZE)] # Run an episode - max_steps = 8 * env.width + env.height + city_count = (env.width * env.height)//300 + max_steps = int(8 * (env.width + env.height + agent_count/city_count)) for step in range(max_steps): update_values = [[False] * agent_count for _ in range(BATCH_SIZE)] action_dict = [{} for _ in range(BATCH_SIZE)] @@ -200,17 +200,14 @@ def normalize(observation, target_tensor): score += sum(sum(r.values()) for r in rewards) / (agent_count * BATCH_SIZE) # Check for collisions and episode completion - all_done = (step == (max_steps - 1)) or any(d['__all__'] for d in done) - if any(is_collision(a, i) for i in range(BATCH_SIZE) for a in range(agent_count)): - collision = True - # done['__all__'] = True - + all_done = (step == (max_steps - 1)) or all(d['__all__'] for d in done) + collision = [[is_collision(a, i) for a in range(agent_count)] for i in range(BATCH_SIZE)] # Update replay buffer and train agent - if flags.train and (any(update_values) or all_done or all(any(d) for d in done)): + if flags.train: agent.step(input_tensor, agent_action_buffer, done, - [[is_collision(a, i) for a in range(agent_count)] for i in range(BATCH_SIZE)], + collision, flags.step_reward, flags.collision_reward) agent_obs_buffer = agent_obs.clone() @@ -233,17 +230,23 @@ def normalize(observation, target_tensor): # print("Collisions detected by agent(s)", ', '.join(str(a) for a in obs if is_collision(a))) # break - - current_collisions, mean_collisions = get_means(current_collisions, mean_collisions, int(collision), episode) + current_collisions, mean_collisions = get_means(current_collisions, mean_collisions, + sum(i for c in collision for i in c) / (BATCH_SIZE * agent_count), + episode) current_score, mean_score = get_means(current_score, mean_score, score / max_steps, episode) current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken / BATCH_SIZE / agent_count, episode) current_taken, mean_taken = get_means(current_steps, mean_steps, step, episode) + current_done, mean_done = get_means(current_done, mean_done, + sum(d[i] for d in done for i in range(agent_count)) / ( + BATCH_SIZE * agent_count), + episode) print(f'\rBatch {episode:>4} - Episode {BATCH_SIZE * episode:>6} - Agents: {agent_count:>3}' - f' | Score: {current_score:.4f}, {mean_score:.4f}' + f' | Score: {current_score:.4f}, {mean_score:.4f}' f' | Agent-Steps: {current_steps:6.1f}, {mean_steps:6.1f}' f' | Steps Taken: {current_taken:6.1f}, {mean_taken:6.1f}' f' | Collisions: {100 * current_collisions:5.2f}%, {100 * mean_collisions:5.2f}%' + f' | Finished: {100 * current_done:5.2f}%, {100 * mean_done:5.2f}%' f' | Episode/s: {BATCH_SIZE * episode / (time.time() - start_time):.4f}s', end='') print("") From f50e92f1b4f9a0de11ece74c451669e526d594c5 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 19 Jul 2020 20:13:02 +0200 Subject: [PATCH 49/75] style(observation-utils): add return_array parameter (backwards compatability) --- src/observation_utils.pyx | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/observation_utils.pyx b/src/observation_utils.pyx index 1951184..bc5c9fa 100644 --- a/src/observation_utils.pyx +++ b/src/observation_utils.pyx @@ -65,11 +65,12 @@ class GlobalObsForRailEnv(ObservationBuilder): - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ target and the positions of the other agents targets (flag only, no counter!). """ - def __init__(self, data_type=np.int8): + def __init__(self, data_type=np.int8, return_array=True): super(GlobalObsForRailEnv, self).__init__() self.size = 0 self._custom_rail_obs = None self.data_type = data_type + self.return_array = return_array def reset(self): if self._custom_rail_obs is None: self._custom_rail_obs = np.zeros((1, self.env.height + 2 * self.size, self.env.width + 2 * self.size, 16)) @@ -122,14 +123,14 @@ class GlobalObsForRailEnv(ObservationBuilder): # fifth channel: all ready to depart on this position if other_agent.status == RailAgentStatus.READY_TO_DEPART: obs_agents_state[(agent_id,) + other_agent.initial_position + (4,)] += 1 - return {i: arr - for i, arr in - enumerate(np.concatenate([np.repeat(self.rail_obs, agent_count, 0), obs_agents_state], -1))} + if self.return_array: + return obs_agents_state + return dict(enumerate(np.concatenate([np.repeat(self.rail_obs, agent_count, 0), obs_agents_state], -1))) class LocalObsForRailEnv(GlobalObsForRailEnv): - def __init__(self, size=7): - super(LocalObsForRailEnv, self).__init__() + def __init__(self, size=7, return_array=True): + super(LocalObsForRailEnv, self).__init__(return_array=return_array) self.size = size def get_many(self, list trash): cdef int agent_count = len(self.env.agents) @@ -179,8 +180,9 @@ class LocalObsForRailEnv(GlobalObsForRailEnv): dist1 = agent_virtual_position[1] - init_pos[1] if abs(dist0) < self.size and abs(dist1) < self.size: obs_agents_state[agent_id, dist0 + self.size, dist1 + self.size, 4] += 1 - return obs_agents_state - + if self.return_array: + return obs_agents_state + return dict(enumerate(obs_agents_state)) class TreeObservation(ObservationBuilder): def __init__(self, max_depth): @@ -323,7 +325,8 @@ class TreeObservation(ObservationBuilder): self.edges_with_malfunctions[(*start.position, start_direction)][agent.handle] = \ (distance, agent.malfunction_data['malfunction']) cdef dict data = super().get_many(handles) - return tuple(tuple(dat.values()) for dat in data) + return data + #return tuple(data.values()) # Compute the observation for a single agent def get(self, int handle): From dbee685b29a6d532bfa2fd4c246745a89a4ec23f Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 19 Jul 2020 20:13:27 +0200 Subject: [PATCH 50/75] perf(train): mildly improve sum iterator --- src/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/train.py b/src/train.py index b8e08d0..2cb3f3a 100644 --- a/src/train.py +++ b/src/train.py @@ -197,7 +197,7 @@ def normalize(observation, target_tensor): # Environment step obs, rewards, done, info = tuple(zip(*[e.step(a) for e, a in zip(environments, action_dict)])) - score += sum(sum(r.values()) for r in rewards) / (agent_count * BATCH_SIZE) + score += sum(i for r in rewards for i in r.values()) / (agent_count * BATCH_SIZE) # Check for collisions and episode completion all_done = (step == (max_steps - 1)) or all(d['__all__'] for d in done) @@ -229,7 +229,6 @@ def normalize(observation, target_tensor): # render() # print("Collisions detected by agent(s)", ', '.join(str(a) for a in obs if is_collision(a))) # break - current_collisions, mean_collisions = get_means(current_collisions, mean_collisions, sum(i for c in collision for i in c) / (BATCH_SIZE * agent_count), episode) From b0107f6db80c2642a9da0dedb79aa6817b41c0a8 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 20 Jul 2020 01:39:14 +0200 Subject: [PATCH 51/75] fix(train): use correct variables for running stats --- src/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/train.py b/src/train.py index 2cb3f3a..d0f1aba 100644 --- a/src/train.py +++ b/src/train.py @@ -234,7 +234,7 @@ def normalize(observation, target_tensor): episode) current_score, mean_score = get_means(current_score, mean_score, score / max_steps, episode) current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken / BATCH_SIZE / agent_count, episode) - current_taken, mean_taken = get_means(current_steps, mean_steps, step, episode) + current_taken, mean_taken = get_means(current_taken, mean_taken, step, episode) current_done, mean_done = get_means(current_done, mean_done, sum(d[i] for d in done for i in range(agent_count)) / ( BATCH_SIZE * agent_count), From 55c72d5b46f9f000c97b12dc34876538c4d28abd Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 20 Jul 2020 01:39:29 +0200 Subject: [PATCH 52/75] style(model): remove unused deps --- src/model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/model.py b/src/model.py index 35293f6..f3e58a5 100644 --- a/src/model.py +++ b/src/model.py @@ -1,7 +1,5 @@ -import math import typing -import numpy as np import torch From 78936f96d4619aaacedf4e0b0dce51ce5dc63873 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 20 Jul 2020 18:16:48 +0200 Subject: [PATCH 53/75] perf(agent): first calculate divisor (+cleanup), then do big op --- src/agent.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/agent.py b/src/agent.py index d1ae6e1..d28efca 100644 --- a/src/agent.py +++ b/src/agent.py @@ -108,15 +108,17 @@ def step(self, state, action, agent_done, collision, step_reward=0, collision_re def learn(self, states, actions, rewards): self.policy.train() actions.unsqueeze_(1) - responsible_outputs = self.policy(states).gather(1, actions) + with torch.no_grad(): - old_responsible_outputs = self.old_policy(states).gather(1, actions) + states_clone = states.clone() + states_clone.requires_grad_(False) + old_responsible_outputs = self.old_policy(states_clone).gather(1, actions) old_responsible_outputs.detach_() + responsible_outputs = self.policy(states).gather(1, actions) ratio = responsible_outputs / (old_responsible_outputs + 1e-5) ratio.squeeze_(1) clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) loss = -torch.min(ratio * rewards, clamped_ratio * rewards).sum(-1).mean() - self.old_policy.load_state_dict(self.policy.state_dict()) self.optimizer.zero_grad() loss.backward() From f695f65a5bc3a544f06800cae93c26f1077d2922 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 20 Jul 2020 18:17:34 +0200 Subject: [PATCH 54/75] perf(model): fix stride --- src/model.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/model.py b/src/model.py index f3e58a5..36585a2 100644 --- a/src/model.py +++ b/src/model.py @@ -23,7 +23,7 @@ def __init__(self, in_features, out_features, kernel_size: typing.Union[int, tup padding: typing.Union[int, tuple] = 0, dilation: typing.Union[int, tuple] = 1, bias=False, dim=1, stride=1, dropout=0.25): super(SeparableConvolution, self).__init__() - self.depthwise = kernel_size > 1 if isinstance(kernel_size, int) else all(k > 1 for k in kernel_size) + self.depthwise = kernel_size > 1 if isinstance(kernel_size, int) else any(k > 1 for k in kernel_size) conv = getattr(torch.nn, f'Conv{dim}d') norm = getattr(torch.nn, f'InstanceNorm{dim}d') if isinstance(kernel_size, int): @@ -43,7 +43,7 @@ def __init__(self, in_features, out_features, kernel_size: typing.Union[int, tup self.pointwise_conv = conv(in_features, out_features, (1,) * dim, bias=bias) self.str = (f'SeparableConvolution({in_features}, {out_features}, {kernel_size}, ' + f'dilation={dilation}, padding={padding}, stride={stride})') - self.dropout = dropout * (in_features == out_features and (stride == 1 or all(stride) == 1)) + self.dropout = dropout * (in_features == out_features and (stride == 1 or all(s == 1 for s in stride))) def forward(self, fn_input: torch.Tensor) -> torch.Tensor: if torch.rand(1) < self.dropout: @@ -68,10 +68,11 @@ def __init__(self, in_features, out_features, stride, init_norm=False, message_b stride=(stride, stride, 1), dim=3) self.mid_norm = torch.nn.InstanceNorm3d(out_features, affine=True) self.end_conv = SeparableConvolution(out_features, out_features, (3, 3, 1), (1, 1, 0), dim=3) - self.shortcut = (nothing - if stride == 1 and in_features == out_features - else SeparableConvolution(in_features, out_features, (1, 1, 1), (0, 0, 0), - stride=(stride, stride, 1), dim=3)) + self.shortcut = torch.nn.Sequential() + if stride > 1: + self.shortcut.add_module("1", torch.nn.AvgPool3d((stride, stride, 1), padding=(1,1,0))) + if in_features != out_features: + self.shortcut.add_module("2", torch.nn.Conv3d(in_features, out_features, 1)) self.message_box = int(out_features ** 0.5) if message_box is None else message_box def forward(self, fn_input: torch.Tensor) -> torch.Tensor: @@ -81,10 +82,11 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: out[:, :self.message_box] = out[:, :self.message_box].mean(-1, keepdim=True).expand(-1, -1, -1, -1, fn_input.size(-1)) - fn_input = self.shortcut(fn_input) - out = out + fn_input + srt = self.shortcut(fn_input) + out = out + srt return out + class ConvNetwork(torch.nn.Module): def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, cat=True, debug=True, softmax=False): From 325ae4e1519c9e9d23c25c36fc532f1172499ccf Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 20 Jul 2020 18:17:53 +0200 Subject: [PATCH 55/75] fix(train): re-add support for model loading --- src/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/train.py b/src/train.py index d0f1aba..cefcddc 100644 --- a/src/train.py +++ b/src/train.py @@ -72,11 +72,11 @@ flags.squeeze_heads, flags.global_environment) if flags.load_model: - start, _ = agent.load(project_root / 'checkpoints', 0, 1.0) + start, = agent.load(project_root / 'checkpoints', 0) else: start = 0 # We need to either load in some pre-generated railways from disk, or else create a random railway generator. -rail_generator, schedule_generator = load_precomputed_railways(project_root, start) +rail_generator, schedule_generator = load_precomputed_railways(project_root, start * BATCH_SIZE) # Create the Flatland environment environments = [RailEnv(width=40, height=40, number_of_agents=1, @@ -153,7 +153,7 @@ def normalize(observation, target_tensor): POOL = multiprocessing.Pool() # Main training loop -episode = 0 +episode = start while True: episode += 1 agent.reset() From 3759e7713a37784aab2e531f9bee5c3defa9c762 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 20 Jul 2020 20:48:24 +0200 Subject: [PATCH 56/75] fix(cythonize): use gcc9 instead of 7 (9 doens't work with cuda10) --- src/cythonize.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cythonize.sh b/src/cythonize.sh index 4aeff8f..58e2047 100644 --- a/src/cythonize.sh +++ b/src/cythonize.sh @@ -1,7 +1,7 @@ function compile { file=${1} cython "$file.pyx" -3 -Wextra -D - cmd="gcc-7 $file.c `python3-config --cflags --ldflags --includes --libs` -I`python -c 'import numpy, sys; sys.stdout.write(numpy.get_include()); sys.stdout.flush()'` -fno-lto -pthread -fPIC -fwrapv -pipe -march=native -mtune=native -Ofast -msse2 -msse4.2 -shared -o $file.so" + cmd="gcc-9 $file.c `python3-config --cflags --ldflags --includes --libs` -I`python -c 'import numpy, sys; sys.stdout.write(numpy.get_include()); sys.stdout.flush()'` -fno-lto -pthread -fPIC -fwrapv -pipe -march=native -mtune=native -Ofast -msse2 -msse4.2 -shared -o $file.so" echo "Executing $cmd" $cmd echo "Testing compilation.." From a11a470d8a461552c7b159bd6c32bd626241451b Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 20 Jul 2020 21:25:25 +0200 Subject: [PATCH 57/75] perf(rail_env): remove rtol for 0 values --- src/rail_env.pyx | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/rail_env.pyx b/src/rail_env.pyx index 97221be..e4dcec1 100644 --- a/src/rail_env.pyx +++ b/src/rail_env.pyx @@ -241,10 +241,9 @@ class RailEnv(Environment): True: Agent needs to provide an action False: Agent cannot provide an action """ - return (agent.status == RailAgentStatus.READY_TO_DEPART or ( - agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, - rtol=1e-03))) + return (agent.status == RailAgentStatus.READY_TO_DEPART or ( + agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0)) def reset(self) -> (Dict, Dict): """ reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed) @@ -441,13 +440,17 @@ class RailEnv(Environment): self._fix_agent_after_malfunction(agent) # Check for end of episode + set global reward to all rewards! - if have_all_agents_ended: - self.dones["__all__"] = True - self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())} if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps): self.dones["__all__"] = True for i_agent in range(self.get_num_agents()): self.dones[i_agent] = True + if have_all_agents_ended: + self.dones["__all__"] = True + self.rewards_dict = {i: self.global_reward * (1 - self.dones[i]) for i in range(self.get_num_agents())} + else: + for i_agent in range(self.get_num_agents()): + if self.dones[i_agent]: + self.rewards_dict[i_agent] = 0 if self.record_steps: self.record_timestep(action_dict_) @@ -493,7 +496,7 @@ class RailEnv(Environment): # Is the agent at the beginning of the cell? Then, it can take an action. # As long as the agent is malfunctioning or stopped at the beginning of the cell, # different actions may be taken! - if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): + if agent.speed_data['position_fraction'] == 0: # No action has been supplied for this agent -> set DO_NOTHING as default if action is None: action = DO_NOTHING @@ -553,8 +556,7 @@ class RailEnv(Environment): # transition_action_on_cellexit if the cell is free. if agent.moving: agent.speed_data['position_fraction'] += agent.speed_data['speed'] - if agent.speed_data['position_fraction'] > 1.0 or np.isclose(agent.speed_data['position_fraction'], 1.0, - rtol=1e-03): + if agent.speed_data['position_fraction'] > 1.0-1e-3: # Perform stored action to transition to the next cell as soon as cell is free # Notice that we've already checked new_cell_valid and transition valid when we stored the action, # so we only have to check cell_free now! From c2db1e5e7e7cc041059c330ce413997132592588 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 20 Jul 2020 21:26:07 +0200 Subject: [PATCH 58/75] feat(train): improve the interface by adding better outputs, _always_ perform an action --- src/train.py | 42 +++++++++++++++--------------------------- 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/src/train.py b/src/train.py index cefcddc..ecbe527 100644 --- a/src/train.py +++ b/src/train.py @@ -92,11 +92,10 @@ random_seed=i) for i in range(BATCH_SIZE)] env = environments[0] -torch.autograd.set_detect_anomaly(True) # After training we want to render the results so we also load a renderer # Add some variables to keep track of the progress -current_score = current_steps = current_collisions = current_done = mean_score = mean_steps = mean_collisions = mean_done = current_taken = mean_taken = None +current_score = current_collisions = current_done = mean_score = mean_collisions = mean_done = current_taken = mean_taken = current_finished = mean_finished = None agent_action_buffer = [] start_time = time.time() @@ -158,8 +157,8 @@ def normalize(observation, target_tensor): episode += 1 agent.reset() obs, info = zip(*[env.reset() for env in environments]) - - score, steps_taken, collision = 0, 0, False + episode_start = time.time() + score, collision = 0, False agent_count = len(obs[0]) if flags.global_environment: agent_obs = torch.as_tensor(obs, dtype=torch.float, device=device) @@ -171,29 +170,16 @@ def normalize(observation, target_tensor): agent_action_buffer = [[2] * agent_count for _ in range(BATCH_SIZE)] # Run an episode - city_count = (env.width * env.height)//300 - max_steps = int(8 * (env.width + env.height + agent_count/city_count)) + city_count = (env.width * env.height) // 300 + max_steps = int(8 * (env.width + env.height + agent_count / city_count)) for step in range(max_steps): - update_values = [[False] * agent_count for _ in range(BATCH_SIZE)] - action_dict = [{} for _ in range(BATCH_SIZE)] if flags.global_environment: input_tensor = torch.cat([agent_obs_buffer, agent_obs], -1) input_tensor.transpose_(1, -1) else: input_tensor = torch.cat([agent_obs_buffer.flatten(1, 2), agent_obs.flatten(1, 2)], 1) - if any(any(inf['action_required']) for inf in info): - ret_action = agent.multi_act(input_tensor) - else: - ret_action = update_values - for idx, act_list in enumerate(ret_action): - for sub_idx, act in enumerate(act_list): - if info[idx]['action_required'][sub_idx]: - action_dict[idx][sub_idx] = act - # action_dict[a] = np.random.randint(5) - update_values[idx][sub_idx] = True - steps_taken += 1 - else: - action_dict[idx][sub_idx] = 0 + ret_action = agent.multi_act(input_tensor) + action_dict = [dict(enumerate(act_list)) for act_list in ret_action] # Environment step obs, rewards, done, info = tuple(zip(*[e.step(a) for e, a in zip(environments, action_dict)])) @@ -233,20 +219,22 @@ def normalize(observation, target_tensor): sum(i for c in collision for i in c) / (BATCH_SIZE * agent_count), episode) current_score, mean_score = get_means(current_score, mean_score, score / max_steps, episode) - current_steps, mean_steps = get_means(current_steps, mean_steps, steps_taken / BATCH_SIZE / agent_count, episode) current_taken, mean_taken = get_means(current_taken, mean_taken, step, episode) current_done, mean_done = get_means(current_done, mean_done, sum(d[i] for d in done for i in range(agent_count)) / ( - BATCH_SIZE * agent_count), + BATCH_SIZE * agent_count), episode) + current_finished, mean_finished = get_means(current_finished, mean_finished, + sum(d['__all__'] for d in done) / BATCH_SIZE, + episode) print(f'\rBatch {episode:>4} - Episode {BATCH_SIZE * episode:>6} - Agents: {agent_count:>3}' f' | Score: {current_score:.4f}, {mean_score:.4f}' - f' | Agent-Steps: {current_steps:6.1f}, {mean_steps:6.1f}' f' | Steps Taken: {current_taken:6.1f}, {mean_taken:6.1f}' - f' | Collisions: {100 * current_collisions:5.2f}%, {100 * mean_collisions:5.2f}%' - f' | Finished: {100 * current_done:5.2f}%, {100 * mean_done:5.2f}%' - f' | Episode/s: {BATCH_SIZE * episode / (time.time() - start_time):.4f}s', end='') + f' | Collisions: {100 * current_collisions:6.2f}%, {100 * mean_collisions:6.2f}%' + f' | Agent Done: {100 * current_done:6.2f}%, {100 * mean_done:6.2f}%' + f' | Finished: {100 * current_finished:6.2f}%, {100 * mean_finished:6.2f}%' + f' | Episode/s: {BATCH_SIZE * episode / (time.time() - start_time):7.4f}s - Took: {time.time()-episode_start:7.1f}', end='') print("") if flags.train: From ae209512f3dad183f1bba520c4fc419fb7542573 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Tue, 21 Jul 2020 03:06:21 +0200 Subject: [PATCH 59/75] feat(train): remove cross-batch mean calculation --- src/train.py | 33 ++++++++------------------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/src/train.py b/src/train.py index ecbe527..f34ce7f 100644 --- a/src/train.py +++ b/src/train.py @@ -95,7 +95,6 @@ # After training we want to render the results so we also load a renderer # Add some variables to keep track of the progress -current_score = current_collisions = current_done = mean_score = mean_collisions = mean_done = current_taken = mean_taken = current_finished = mean_finished = None agent_action_buffer = [] start_time = time.time() @@ -126,10 +125,6 @@ def is_collision(a, i): return False -def get_means(x, y, c, s): - return c if x is None else (x * 3 + c) / 4, c if y is None else (y * (s - 1) + c) / s - - chunk_size = (BATCH_SIZE + 1) // flags.threads @@ -215,26 +210,14 @@ def normalize(observation, target_tensor): # render() # print("Collisions detected by agent(s)", ', '.join(str(a) for a in obs if is_collision(a))) # break - current_collisions, mean_collisions = get_means(current_collisions, mean_collisions, - sum(i for c in collision for i in c) / (BATCH_SIZE * agent_count), - episode) - current_score, mean_score = get_means(current_score, mean_score, score / max_steps, episode) - current_taken, mean_taken = get_means(current_taken, mean_taken, step, episode) - current_done, mean_done = get_means(current_done, mean_done, - sum(d[i] for d in done for i in range(agent_count)) / ( - BATCH_SIZE * agent_count), - episode) - current_finished, mean_finished = get_means(current_finished, mean_finished, - sum(d['__all__'] for d in done) / BATCH_SIZE, - episode) - - print(f'\rBatch {episode:>4} - Episode {BATCH_SIZE * episode:>6} - Agents: {agent_count:>3}' - f' | Score: {current_score:.4f}, {mean_score:.4f}' - f' | Steps Taken: {current_taken:6.1f}, {mean_taken:6.1f}' - f' | Collisions: {100 * current_collisions:6.2f}%, {100 * mean_collisions:6.2f}%' - f' | Agent Done: {100 * current_done:6.2f}%, {100 * mean_done:6.2f}%' - f' | Finished: {100 * current_finished:6.2f}%, {100 * mean_finished:6.2f}%' - f' | Episode/s: {BATCH_SIZE * episode / (time.time() - start_time):7.4f}s - Took: {time.time()-episode_start:7.1f}', end='') + + print(f'\rBatch{episode:>3} - Episode{BATCH_SIZE * episode:>5} - Agents:{agent_count:>3}' + f' | Score: {score / max_steps:.4f}' + f' | Steps: {step:4.0f}' + f' | Collisions: {100 * sum(i for c in collision for i in c) / (BATCH_SIZE * agent_count):6.2f}%' + f' | Done: {100 * sum(d[i] for d in done for i in range(agent_count)) / (BATCH_SIZE * agent_count):6.2f}%' + f' | Finished: {100 * sum(d["__all__"] for d in done) / BATCH_SIZE:6.2f}%' + f' | Took: {time.time()-episode_start:5.0f}s', end='') print("") if flags.train: From 696da8592d9d8aae731904b087344489c37e0868 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 24 Jul 2020 10:29:10 +0200 Subject: [PATCH 60/75] feat: add global-state model (globalobs + agent-state) --- src/agent.py | 42 ++++++++---- src/cythonize.sh | 8 +-- src/generate_railways.py | 19 +++--- src/model.py | 140 ++++++++++++++++++++++++++++---------- src/observation_utils.pyx | 57 +++++++++++++--- src/railway_utils.py | 24 ++++--- src/train.py | 128 +++++++++++++++++++++++----------- 7 files changed, 298 insertions(+), 120 deletions(-) diff --git a/src/agent.py b/src/agent.py index d28efca..b22a63a 100644 --- a/src/agent.py +++ b/src/agent.py @@ -6,9 +6,9 @@ from torch_optimizer import Yogi as Optimizer try: - from .model import QNetwork, ConvNetwork, init + from .model import QNetwork, ConvNetwork, init, GlobalStateNetwork except: - from model import QNetwork, ConvNetwork, init + from model import QNetwork, ConvNetwork, init, GlobalStateNetwork import os BUFFER_SIZE = 500_000 @@ -25,21 +25,24 @@ class Agent: - def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_size, squeeze_heads, - use_global=False, softmax=True, debug=True): + def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_size, squeeze_heads, decoder_depth, + model_type=0, softmax=True, debug=True): self.action_size = action_size # Q-Network - if use_global: + if model_type == 1: # Global/Local network = ConvNetwork - else: + elif model_type == 0: # Tree network = QNetwork + else: # Global State + network = GlobalStateNetwork self.policy = network(state_size, action_size, hidden_factor, model_depth, kernel_size, squeeze_heads, + decoder_depth, softmax=softmax).to(device) self.old_policy = network(state_size, action_size, @@ -47,6 +50,7 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s model_depth, kernel_size, squeeze_heads, + decoder_depth, softmax=softmax, debug=False).to(device) if debug: @@ -74,13 +78,17 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s self.t_step = 0 def reset(self): - self.finished = False + self.policy.reset_cache() + self.old_policy.reset_cache() def multi_act(self, state): - state = state.to(device) + if isinstance(state, tuple): + state = tuple(s.to(device) for s in state) + elif isinstance(state, torch.Tensor): + state = (state.to(device),) self.policy.eval() with torch.no_grad(): - action_values = self.policy(state) + action_values = self.policy(*state) # Epsilon-greedy action selection return action_values.argmax(1).detach().cpu().numpy() @@ -101,7 +109,12 @@ def step(self, state, action, agent_done, collision, step_reward=0, collision_re else (collision_reward if c else step_reward) for ad, c in zip(ad_batch, c_batch)] for ad_batch, c_batch in zip(ad_step, c_step)] for ad_step, c_step in zip(self.stack[2], self.stack[3])]).flatten(0, 1).to(device) - state = torch.cat(self.stack[0], 0).to(device) + state = self.stack[0] + if isinstance(state[0], tuple): + state = zip(*state) + state = tuple(torch.cat(st, 0).to(device) for st in state) + elif isinstance(state[0], torch.Tensor): + state = (torch.cat(state, 0).to(device),) self.stack = [[] for _ in range(4)] self.learn(state, action, reward) @@ -110,11 +123,12 @@ def learn(self, states, actions, rewards): actions.unsqueeze_(1) with torch.no_grad(): - states_clone = states.clone() - states_clone.requires_grad_(False) - old_responsible_outputs = self.old_policy(states_clone).gather(1, actions) + states_clone = tuple(st.clone() for st in states) + for st in states: + st.requires_grad_(False) + old_responsible_outputs = self.old_policy(*states_clone).gather(1, actions) old_responsible_outputs.detach_() - responsible_outputs = self.policy(states).gather(1, actions) + responsible_outputs = self.policy(*states).gather(1, actions) ratio = responsible_outputs / (old_responsible_outputs + 1e-5) ratio.squeeze_(1) clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) diff --git a/src/cythonize.sh b/src/cythonize.sh index 58e2047..f44bc17 100644 --- a/src/cythonize.sh +++ b/src/cythonize.sh @@ -1,13 +1,13 @@ function compile { file=${1} cython "$file.pyx" -3 -Wextra -D - cmd="gcc-9 $file.c `python3-config --cflags --ldflags --includes --libs` -I`python -c 'import numpy, sys; sys.stdout.write(numpy.get_include()); sys.stdout.flush()'` -fno-lto -pthread -fPIC -fwrapv -pipe -march=native -mtune=native -Ofast -msse2 -msse4.2 -shared -o $file.so" - echo "Executing $cmd" - $cmd + flags="$file.c `python3-config --cflags --ldflags --includes --libs` -I`python -c 'import numpy, sys; sys.stdout.write(numpy.get_include()); sys.stdout.flush()'` -fno-lto -pthread -fPIC -fwrapv -pipe -march=native -mtune=native -Ofast -msse2 -msse4.2 -shared -o $file.so" + echo "Executing gcc with $flags" + (gcc-9 $flags) || (gcc-7 $flags) echo "Testing compilation.." python3 -c "import $file" echo } compile observation_utils -compile rail_env \ No newline at end of file +compile rail_env diff --git a/src/generate_railways.py b/src/generate_railways.py index c65e0f6..0034d91 100755 --- a/src/generate_railways.py +++ b/src/generate_railways.py @@ -16,16 +16,19 @@ parser = argparse.ArgumentParser(description="Train an agent in the flatland environment") parser.add_argument("--width", type=int, default=35, help="Decay factor for epsilon-greedy exploration") +parser.add_argument("--factor", type=int, default=2, help="Decay factor for epsilon-greedy exploration") +parser.add_argument("--base", type=float, default=1.1, help="Decay factor for epsilon-greedy exploration") flags = parser.parse_args() -width = flags.width -rail_generator, schedule_generator = create_random_railways(flags.width) +rail_generator, schedule_generator = create_random_railways(flags.width, flags.base, flags.factor) # Load in any existing railways for this map size so we don't overwrite them +network = project_root / f'railroads/rail_networks_{flags.width}_{flags.factor}.pkl' +sched = project_root / f'railroads/schedules_{flags.width}_{flags.factor}.pkl' try: - with open(project_root / f'railroads/rail_networks_{width}.pkl', 'rb') as file: + with open(network, 'rb') as file: rail_networks = pickle.load(file) - with open(project_root / f'railroads/schedules_{width}.pkl', 'rb') as file: + with open(sched, 'rb') as file: schedules = pickle.load(file) print(f"Loading {len(rail_networks)} railways...") except: @@ -34,7 +37,7 @@ def do(schedules: list, rail_networks: list): for _ in range(100): - map, info = rail_generator(width, 1, 1, num_resets=0, np_random=np.random) + map, info = rail_generator(flags.width, 1, 1, num_resets=0, np_random=np.random) schedule = schedule_generator(map, 1, info['agents_hints'], num_resets=0, np_random=np.random) rail_networks.append((map, info)) schedules.append(schedule) @@ -43,10 +46,10 @@ def do(schedules: list, rail_networks: list): for _ in tqdm(range(500), ncols=150, leave=False): do(schedules, rail_networks) - with open(project_root / f'railroads/rail_networks_{width}.pkl', 'wb') as file: - pickle.dump(schedules, file, protocol=4) - with open(project_root / f'railroads/schedules_{width}.pkl', 'wb') as file: + with open(network, 'wb') as file: pickle.dump(rail_networks, file, protocol=4) + with open(sched, 'wb') as file: + pickle.dump(schedules, file, protocol=4) print(f"Saved {len(rail_networks)} railways") print("Done") diff --git a/src/model.py b/src/model.py index 36585a2..12e69ed 100644 --- a/src/model.py +++ b/src/model.py @@ -21,7 +21,7 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: class SeparableConvolution(torch.nn.Module): def __init__(self, in_features, out_features, kernel_size: typing.Union[int, tuple], padding: typing.Union[int, tuple] = 0, dilation: typing.Union[int, tuple] = 1, - bias=False, dim=1, stride=1, dropout=0.25): + bias=False, dim=1, stride: typing.Union[int, tuple] = 1, dropout=0.25): super(SeparableConvolution, self).__init__() self.depthwise = kernel_size > 1 if isinstance(kernel_size, int) else any(k > 1 for k in kernel_size) conv = getattr(torch.nn, f'Conv{dim}d') @@ -60,36 +60,48 @@ def __repr__(self): class BasicBlock(torch.nn.Module): - def __init__(self, in_features, out_features, stride, init_norm=False, message_box=None): + def __init__(self, in_features, out_features, stride, init_norm=False, message_box=None, double=True, + agent_dim=True): super(BasicBlock, self).__init__() self.activate = init_norm - self.init_norm = torch.nn.InstanceNorm3d(in_features, affine=True) if init_norm else nothing - self.init_conv = SeparableConvolution(in_features, out_features, (3, 3, 1), (1, 1, 0), - stride=(stride, stride, 1), dim=3) - self.mid_norm = torch.nn.InstanceNorm3d(out_features, affine=True) - self.end_conv = SeparableConvolution(out_features, out_features, (3, 3, 1), (1, 1, 0), dim=3) + self.double = double + self.agent_dim = agent_dim + dim = 2 + agent_dim + norm = getattr(torch.nn, f'InstanceNorm{dim}d') + kernel = (3, 3) + ((1,) if agent_dim else ()) + pad = (1, 1) + ((0,) if agent_dim else ()) + stride = (stride, stride) + ((1,) if agent_dim else ()) + self.init_norm = norm(in_features, affine=True) if init_norm else nothing + self.init_conv = SeparableConvolution(in_features, out_features, kernel, pad, stride=stride, dim=dim) + if double: + self.mid_norm = norm(out_features, affine=True) + self.end_conv = SeparableConvolution(out_features, out_features, kernel, pad, dim=dim) + else: + self.mid_norm = nothing + self.end_conv = nothing self.shortcut = torch.nn.Sequential() - if stride > 1: - self.shortcut.add_module("1", torch.nn.AvgPool3d((stride, stride, 1), padding=(1,1,0))) + if stride[0] > 1: + self.shortcut.add_module("1", getattr(torch.nn, f"MaxPool{dim}d")(kernel, stride, padding=pad)) if in_features != out_features: - self.shortcut.add_module("2", torch.nn.Conv3d(in_features, out_features, 1)) + self.shortcut.add_module("2", getattr(torch.nn, f"Conv{dim}d")(in_features, out_features, 1)) self.message_box = int(out_features ** 0.5) if message_box is None else message_box def forward(self, fn_input: torch.Tensor) -> torch.Tensor: out = self.init_conv(fn_input if self.activate is None else mish(self.init_norm(fn_input))) - out = mish(self.mid_norm(out)) - out: torch.Tensor = self.end_conv(out) - out[:, :self.message_box] = out[:, :self.message_box].mean(-1, - keepdim=True).expand(-1, -1, -1, -1, - fn_input.size(-1)) + if self.double: + out = self.end_conv(mish(self.mid_norm(out))) + if self.agent_dim and self.message_box > 0: + out[:, :self.message_box] = out[:, :self.message_box].mean(-1, + keepdim=True).expand(-1, -1, -1, -1, + fn_input.size(-1)) srt = self.shortcut(fn_input) out = out + srt return out class ConvNetwork(torch.nn.Module): - def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, cat=True, - debug=True, softmax=False): + def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, + decoder_depth=1, cat=True, debug=True, softmax=False): super(ConvNetwork, self).__init__() _ = state_size state_size = 2 * 21 @@ -99,9 +111,7 @@ def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size init_norm=bool(i)) for i in range(depth)]) self.init_norm = torch.nn.InstanceNorm1d(hidden_size, affine=True) - self.linear0 = torch.nn.Conv1d(hidden_size, hidden_size, 1, bias=False) - self.mid_norm = torch.nn.InstanceNorm1d(hidden_size, affine=True) - self.linear1 = torch.nn.Conv1d(hidden_size, action_size, 1) + self.linear = torch.nn.Conv1d(hidden_size, action_size, 1) self.softmax = softmax def forward(self, fn_input: torch.Tensor) -> torch.Tensor: @@ -109,7 +119,58 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: for module in self.net: out = module(out) out = out.mean((2, 3)) - out = self.linear1(mish(self.mid_norm(self.linear0(mish(self.init_norm(out)))))) + out = self.linear(mish(self.init_norm(out))) + return torch.nn.functional.softmax(out, 1) if self.softmax else out + + @torch.jit.export + def reset_cache(self): + pass + + +class GlobalStateNetwork(torch.nn.Module): + def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, + decoder_depth=1 , cat=True, debug=True, softmax=False): + super(GlobalStateNetwork, self).__init__() + _ = state_size + _ = kernel_size + _ = squeeze_heads + _ = cat + _ = debug + global_state_size = 2 * 16 + agent_state_size = 2 * 13 + + self.net = torch.nn.Sequential(*[BasicBlock(global_state_size if not i else hidden_size, + hidden_size, + 2, + init_norm=bool(i), + message_box=0, + double=False, + agent_dim=False) + for i in range(depth)]) + self.decoder = torch.nn.Sequential(*[layer + for i in range(decoder_depth - 1) + for layer in (torch.nn.Conv1d(hidden_size + (0 if i else agent_state_size), + hidden_size, + 1, + bias=False), + torch.nn.InstanceNorm1d(hidden_size, affine=True), + Mish())], + torch.nn.Conv1d(hidden_size, action_size, 1)) + self.softmax = softmax + + self.register_buffer("base_zero", torch.zeros(1)) + self.encoding_cache = self.base_zero + + @torch.jit.export + def reset_cache(self): + self.encoding_cache = self.base_zero + + def forward(self, state, rail) -> torch.Tensor: + if torch.equal(self.encoding_cache, self.base_zero): + self.encoding_cache = self.net(rail) + self.encoding_cache = self.encoding_cache.mean((2, 3), keepdim=True).squeeze(-1) + inp = torch.cat([self.encoding_cache.clone().expand(-1, -1, state.size(-1)), state], 1) + out = self.decoder(inp) return torch.nn.functional.softmax(out, 1) if self.softmax else out @@ -128,7 +189,7 @@ class Residual(torch.nn.Module): def __init__(self, features): super(Residual, self).__init__() self.norm = torch.nn.BatchNorm1d(features) - self.conv = torch.nn.Conv1d(features, 2 * features) + self.conv = torch.nn.Conv1d(features, 2 * features, 1) def forward(self, fn_input: torch.Tensor) -> torch.Tensor: out, exc = self.conv(mish(self.norm(fn_input))).chunk(2, 1) @@ -138,18 +199,25 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: return fn_input + out -def QNetwork(state_size, action_size, hidden_factor=16, depth=4, kernel_size=7, squeeze_heads=4, cat=False, - debug=True): - model = torch.nn.Sequential(torch.nn.Conv1d(2 * state_size, 11 * hidden_factor, 1, groups=11, bias=False), - Residual(11 * hidden_factor), - torch.nn.BatchNorm1d(11 * hidden_factor), - Mish(), - torch.nn.Conv1d(11 * hidden_factor, action_size, 1)) - print(model) - - model.apply(init) - try: - model = torch.jit.script(model) - except TypeError: +class QNetwork(torch.nn.Sequential): + def __init__(self, state_size, action_size, hidden_factor=16, depth=4, kernel_size=7, squeeze_heads=4, + decoder_depth=1, cat=False, debug=True): + super(QNetwork, self).__init__() + _ = depth + _ = kernel_size + _ = squeeze_heads + _ = cat + _ = debug + _ = decoder_depth + self.model = torch.nn.Sequential(torch.nn.Conv1d(2 * state_size, 11 * hidden_factor, 1, groups=11, bias=False), + Residual(11 * hidden_factor), + torch.nn.BatchNorm1d(11 * hidden_factor), + Mish(), + torch.nn.Conv1d(11 * hidden_factor, action_size, 1)) + + @torch.jit.export + def reset_cache(self): pass - return model + + def forward(self, *args): + return self.model(*args) diff --git a/src/observation_utils.pyx b/src/observation_utils.pyx index bc5c9fa..61d10db 100644 --- a/src/observation_utils.pyx +++ b/src/observation_utils.pyx @@ -32,6 +32,9 @@ cpdef int get_direction(int orientation, int action): else: return orientation +cpdef set_array(tuple data, int width, cnp.ndarray arr, int base_index, int start_index): + arr[start_index:start_index + 4, base_index] = data + (data[0] / width, data[1] / width) + cdef class RailNode: cdef public dict edges cdef public tuple position @@ -74,15 +77,18 @@ class GlobalObsForRailEnv(ObservationBuilder): def reset(self): if self._custom_rail_obs is None: self._custom_rail_obs = np.zeros((1, self.env.height + 2 * self.size, self.env.width + 2 * self.size, 16)) - - self._custom_rail_obs[0, self.size:-self.size, self.size:-self.size] = np.array([[[[1 if digit == '1' else 0 - for digit in - f'{self.env.rail.get_full_transitions(i, j):016b}'] - for j in - range(self.env.width)] - for i in - range(self.env.height)]], - dtype=self.data_type) + cdef cnp.ndarray out = np.array([[[[1 if digit == '1' else 0 + for digit in + f'{self.env.rail.get_full_transitions(i, j):016b}'] + for j in + range(self.env.width)] + for i in + range(self.env.height)]], + dtype=self.data_type) + if self.size > 0: + self._custom_rail_obs[0, self.size:-self.size, self.size:-self.size] = out + else: + self._custom_rail_obs = out def get_many(self, list trash): cdef int agent_count = len(self.env.agents) @@ -184,6 +190,39 @@ class LocalObsForRailEnv(GlobalObsForRailEnv): return obs_agents_state return dict(enumerate(obs_agents_state)) + +class GlobalStateObs(GlobalObsForRailEnv): + def __init__(self, return_array=True): + super(GlobalStateObs, self).__init__(return_array=return_array) + def get_many(self, list trash): + cdef int agent_count = len(self.env.agents) + cdef cnp.ndarray obs_agents_state = np.zeros((13, agent_count), dtype=np.float32) + cdef int i, agent_id + cdef tuple agent_virtual_position + + for agent_id, agent in enumerate(self.env.agents): + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_virtual_position = agent.target + else: # Done+Removed + continue + + obs_agents_state[0, agent_id] = agent.direction + obs_agents_state[1, agent_id] = agent.malfunction_data['malfunction'] + obs_agents_state[2, agent_id] = agent.speed_data['speed'] + obs_agents_state[3, agent_id] = agent.status + obs_agents_state[4, agent_id] = agent.moving + set_array(agent_virtual_position, self.env.width, obs_agents_state, agent_id, 5) + set_array(agent.target, self.env.width, obs_agents_state, agent_id, 9) + + if self.return_array: + return obs_agents_state, self._custom_rail_obs[0] + return dict(enumerate(obs_agents_state)) + + class TreeObservation(ObservationBuilder): def __init__(self, max_depth): super().__init__() diff --git a/src/railway_utils.py b/src/railway_utils.py index 182e8aa..56bc8fe 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -49,7 +49,7 @@ def __call__(self, *args, **kwargs): class RailGenerator: - def __init__(self, width=35, base=1.5): + def __init__(self, width=35, base=1.5, factor=2): self.rail_generator = sparse_rail_generator(grid_mode=False, max_num_cities=max(2, width ** 2 // 300), max_rails_between_cities=2, @@ -58,31 +58,33 @@ def __init__(self, width=35, base=1.5): self.top_idx = 0 self.width = width self.base = base + self.factor = factor def __next__(self): self.sub_idx += 1 if self.sub_idx == BATCH_SIZE: self.sub_idx = 0 self.top_idx += 1 - return self.rail_generator(self.width, self.width, int(2 * self.base ** self.top_idx), np_random=np.random) + return self.rail_generator(self.width, self.width, int(self.factor * self.base ** self.top_idx), np_random=np.random) def __call__(self, *args, **kwargs): return next(self) class ScheduleGenerator: - def __init__(self, base=1.5): + def __init__(self, base=1.5, factor=2): self.schedule_generator = sparse_schedule_generator({1.: 1.}) self.sub_idx = 0 self.top_idx = 0 self.base = base + self.factor = factor def __next__(self, rail, hints): if self.sub_idx == BATCH_SIZE: self.sub_idx = 0 self.top_idx += 1 self.sub_idx += 1 - return self.schedule_generator(rail, int(2 * self.base ** self.top_idx), hints, np_random=np.random) + return self.schedule_generator(rail, int(self.factor * self.base ** self.top_idx), hints, np_random=np.random) def __call__(self, rail, _, hints, *args, **kwargs): return self.__next__(rail, hints) @@ -92,17 +94,17 @@ def __call__(self, rail, _, hints, *args, **kwargs): def load_precomputed_railways(project_root, start_index, big=False): prefix = os.path.join(project_root, 'railroads') if big: - suffix = f'_110.pkl' + suffix = f'_35_6.pkl' else: suffix = f'_50.pkl' - sched = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) - rail = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) - #if big: - # sched, rail = rail, sched + rail = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) + sched = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) + if not big: + sched, rail = rail, sched print(f"Working on {len(rail)} tracks") return rail, sched # Helper function to generate railways on the fly -def create_random_railways(width, base=1.1): - return RailGenerator(width=width, base=base), ScheduleGenerator(base=base) +def create_random_railways(width, base=1.1, factor=2): + return RailGenerator(width=width, base=base, factor=factor), ScheduleGenerator(base=base, factor=factor) diff --git a/src/train.py b/src/train.py index f34ce7f..b756318 100644 --- a/src/train.py +++ b/src/train.py @@ -1,12 +1,17 @@ import argparse +import copy import time from itertools import zip_longest from pathlib import Path - +import numpy as np import torch from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters +# from flatland.utils.rendertools import RenderTool, AgentRenderVariant from pathos import multiprocessing +# import cv2 + + torch.jit.optimized_execution(True) positive_infinity = int(1e5) @@ -16,13 +21,15 @@ from .rail_env import RailEnv from .agent import Agent as DQN_Agent, device, BATCH_SIZE from .normalize_output_data import wrap - from .observation_utils import normalize_observation, TreeObservation, GlobalObsForRailEnv, LocalObsForRailEnv + from .observation_utils import normalize_observation, TreeObservation, GlobalObsForRailEnv, LocalObsForRailEnv, \ + GlobalStateObs from .railway_utils import load_precomputed_railways, create_random_railways except: from rail_env import RailEnv from agent import Agent as DQN_Agent, device, BATCH_SIZE from normalize_output_data import wrap - from observation_utils import normalize_observation, TreeObservation, GlobalObsForRailEnv, LocalObsForRailEnv + from observation_utils import normalize_observation, TreeObservation, GlobalObsForRailEnv, LocalObsForRailEnv, \ + GlobalStateObs from railway_utils import load_precomputed_railways, create_random_railways project_root = Path(__file__).resolve().parent.parent @@ -42,18 +49,36 @@ parser.add_argument("--kernel-size", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--squeeze-heads", type=int, default=4, help="Depth of the observation tree") parser.add_argument("--observation-size", type=int, default=4, help="Depth of the observation tree") +parser.add_argument("--decoder-depth", type=int, default=1, help="Depth of the observation tree") # Training parameters parser.add_argument("--step-reward", type=float, default=-1e-2, help="Depth of the observation tree") parser.add_argument("--collision-reward", type=float, default=-2, help="Depth of the observation tree") -parser.add_argument("--global-environment", type=boolean, default=True, help="Depth of the observation tree") -parser.add_argument("--local-environment", type=boolean, default=True, help="Depth of the observation tree") +parser.add_argument("--global-environment", type=boolean, default=False, help="Depth of the observation tree") +parser.add_argument("--local-environment", type=boolean, default=False, help="Depth of the observation tree") +parser.add_argument("--state-environment", type=boolean, default=True, help="Depth of the observation tree") parser.add_argument("--threads", type=int, default=1, help="Depth of the observation tree") flags = parser.parse_args() -if flags.local_environment: - flags.global_environment = True +if sum((flags.global_environment, flags.local_environment, flags.state_environment)) > 1: + print("Too many environment flags used. Priority is global > local > state.") + +if flags.global_environment: + model_type = 1 + env = GlobalObsForRailEnv() +elif flags.local_environment: + model_type = 1 + env = LocalObsForRailEnv(flags.observation_size) +elif flags.state_environment: + model_type = 2 + env = GlobalStateObs() +else: + model_type = 0 + env = TreeObservation(flags.tree_depth) + +if model_type not in (0, 1, 2): + raise UserWarning("Unknown model type") # Seeded RNG so we can replicate our results @@ -70,7 +95,8 @@ flags.hidden_factor, flags.kernel_size, flags.squeeze_heads, - flags.global_environment) + flags.decoder_depth, + model_type) if flags.load_model: start, = agent.load(project_root / 'checkpoints', 0) else: @@ -79,16 +105,12 @@ rail_generator, schedule_generator = load_precomputed_railways(project_root, start * BATCH_SIZE) # Create the Flatland environment -environments = [RailEnv(width=40, height=40, number_of_agents=1, +environments = [RailEnv(width=50, height=50, number_of_agents=1, rail_generator=rail_generator, schedule_generator=schedule_generator, malfunction_generator_and_process_data=malfunction_from_params( MalfunctionParameters(1 / 500, 20, 50)), - obs_builder_object=((LocalObsForRailEnv(flags.observation_size) - if flags.local_environment - else GlobalObsForRailEnv) - if flags.global_environment - else TreeObservation(max_depth=flags.tree_depth)), + obs_builder_object=copy.deepcopy(env), random_seed=i) for i in range(BATCH_SIZE)] env = environments[0] @@ -102,13 +124,13 @@ # Helper function to detect collisions ACTIONS = {0: 'B', 1: 'L', 2: 'F', 3: 'R', 4: 'S'} -if flags.global_environment: +if model_type in (1, 2): def is_collision(a, i): own_agent = environments[i].agents[a] return any(own_agent.position == agent.position for agent_id, agent in enumerate(environments[i].agents) if agent_id != a) -else: +else: #model_type == 0 def is_collision(a, i): if obs[i][a] is None: return False is_junction = not isinstance(obs[i][a].childs['L'], float) or not isinstance(obs[i][a].childs['R'], float) @@ -124,7 +146,6 @@ def is_collision(a, i): else: return False - chunk_size = (BATCH_SIZE + 1) // flags.threads @@ -143,8 +164,16 @@ def normalize(observation, target_tensor): normalize_observation(observation, flags.tree_depth, target_tensor, 0) wrap(target_tensor) +def as_tensor(array_list): + return torch.as_tensor(np.stack(array_list, 0), dtype=torch.float, device=device) + episode = 0 POOL = multiprocessing.Pool() +# env_renderer = None +# def render(): +# env_renderer.render_env(show_observations=False) +# cv2.imshow('Render', cv2.cvtColor(env_renderer.get_image(), cv2.COLOR_BGR2RGB)) +# cv2.waitKey(120) # Main training loop episode = start @@ -152,46 +181,60 @@ def normalize(observation, target_tensor): episode += 1 agent.reset() obs, info = zip(*[env.reset() for env in environments]) + # env_renderer = RenderTool(environments[0], gl="PILSVG", screen_width=1000, screen_height=1000, agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX) + # env_renderer.reset() episode_start = time.time() score, collision = 0, False agent_count = len(obs[0]) - if flags.global_environment: - agent_obs = torch.as_tensor(obs, dtype=torch.float, device=device) - else: + if model_type == 1: + agent_obs = as_tensor(obs) + agent_obs_buffer = agent_obs.clone() + elif model_type == 0: agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) normalize(obs, agent_obs) + agent_obs_buffer = agent_obs.clone() + else: # model_type == 2 + rail, obs = zip(*obs) + agent_obs = as_tensor(rail), as_tensor(obs) + agent_obs_buffer = agent_obs[0].clone(), agent_obs[1].clone() - agent_obs_buffer = agent_obs.clone() agent_action_buffer = [[2] * agent_count for _ in range(BATCH_SIZE)] # Run an episode city_count = (env.width * env.height) // 300 - max_steps = int(8 * (env.width + env.height + agent_count / city_count)) + max_steps = int(8 * (env.width + env.height + agent_count / city_count)) - 10 + # -10 = have some distance to the "real" max steps + done = [[False]] + _done = [[False]] for step in range(max_steps): - if flags.global_environment: + done = _done + if model_type == 1: input_tensor = torch.cat([agent_obs_buffer, agent_obs], -1) - input_tensor.transpose_(1, -1) - else: + elif model_type == 0: input_tensor = torch.cat([agent_obs_buffer.flatten(1, 2), agent_obs.flatten(1, 2)], 1) + else: # model_type == 2 + input_tensor = (torch.cat((agent_obs_buffer[0], agent_obs[0]), 1), + torch.cat((agent_obs_buffer[1], agent_obs[1]), -1)) + input_tensor[1].transpose_(1, -1) + ret_action = agent.multi_act(input_tensor) action_dict = [dict(enumerate(act_list)) for act_list in ret_action] # Environment step - obs, rewards, done, info = tuple(zip(*[e.step(a) for e, a in zip(environments, action_dict)])) + obs, rewards, _done, info = tuple(zip(*[e.step(a) for e, a in zip(environments, action_dict)])) score += sum(i for r in rewards for i in r.values()) / (agent_count * BATCH_SIZE) # Check for collisions and episode completion - all_done = (step == (max_steps - 1)) or all(d['__all__'] for d in done) + all_done = (step == (max_steps - 1)) or all(d['__all__'] for d in _done) collision = [[is_collision(a, i) for a in range(agent_count)] for i in range(BATCH_SIZE)] # Update replay buffer and train agent if flags.train: agent.step(input_tensor, agent_action_buffer, - done, + _done, collision, flags.step_reward, flags.collision_reward) - agent_obs_buffer = agent_obs.clone() for idx, act in enumerate(action_dict): for key, value in act.items(): agent_action_buffer[idx][key] = value @@ -199,26 +242,35 @@ def normalize(observation, target_tensor): if all_done: break - if flags.global_environment: - agent_obs = torch.as_tensor(obs, dtype=torch.float, device=device) - else: + if model_type == 1: + agent_obs_buffer = agent_obs.clone() + agent_obs = as_tensor(obs) + elif model_type == 0: + agent_obs_buffer = agent_obs.clone() + agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) normalize(obs, agent_obs) + else: # model_type == 2 + rail, obs = zip(*obs) + agent_obs_buffer = agent_obs[0].clone(), agent_obs[1].clone() + agent_obs = as_tensor(rail), as_tensor(obs) # Render # if flags.render_interval and episode % flags.render_interval == 0: # if collision and all(agent.position for agent in env.agents): - # render() + # if step % 2 == 1: + # print([a.position for a in environments[0].agents]) + # render() # print("Collisions detected by agent(s)", ', '.join(str(a) for a in obs if is_collision(a))) # break print(f'\rBatch{episode:>3} - Episode{BATCH_SIZE * episode:>5} - Agents:{agent_count:>3}' f' | Score: {score / max_steps:.4f}' - f' | Steps: {step:4.0f}' + f' | Steps: {step:4.0f}' f' | Collisions: {100 * sum(i for c in collision for i in c) / (BATCH_SIZE * agent_count):6.2f}%' - f' | Done: {100 * sum(d[i] for d in done for i in range(agent_count)) / (BATCH_SIZE * agent_count):6.2f}%' + f' | Done: {100 * sum(d[i] for d in done for i in range(agent_count)) / (BATCH_SIZE * agent_count):6.2f}%' f' | Finished: {100 * sum(d["__all__"] for d in done) / BATCH_SIZE:6.2f}%' - f' | Took: {time.time()-episode_start:5.0f}s', end='') + f' | Took: {time.time() - episode_start:5.0f}s', end='') print("") - if flags.train: - agent.save(project_root / 'checkpoints', episode) +# if flags.train: +# agent.save(project_root / 'checkpoints', episode) From c10860b496800c53d86ab91e2016186302245943 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 24 Jul 2020 11:02:23 +0200 Subject: [PATCH 61/75] perf(agent): enforce sparse gradient (experiment) --- src/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agent.py b/src/agent.py index b22a63a..8dea564 100644 --- a/src/agent.py +++ b/src/agent.py @@ -132,7 +132,7 @@ def learn(self, states, actions, rewards): ratio = responsible_outputs / (old_responsible_outputs + 1e-5) ratio.squeeze_(1) clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) - loss = -torch.min(ratio * rewards, clamped_ratio * rewards).sum(-1).mean() + loss = -torch.min(ratio * rewards, clamped_ratio * rewards).sum(-1).max() self.old_policy.load_state_dict(self.policy.state_dict()) self.optimizer.zero_grad() loss.backward() From c46b1c5d5baf6719a454a1c7e83f63ae38c3aa9c Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 24 Jul 2020 11:02:50 +0200 Subject: [PATCH 62/75] feat(interface): move depth-1 --- src/model.py | 6 ++++-- src/train.py | 8 +++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/model.py b/src/model.py index 12e69ed..159b968 100644 --- a/src/model.py +++ b/src/model.py @@ -148,14 +148,16 @@ def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size agent_dim=False) for i in range(depth)]) self.decoder = torch.nn.Sequential(*[layer - for i in range(decoder_depth - 1) + for i in range(decoder_depth) for layer in (torch.nn.Conv1d(hidden_size + (0 if i else agent_state_size), hidden_size, 1, bias=False), torch.nn.InstanceNorm1d(hidden_size, affine=True), Mish())], - torch.nn.Conv1d(hidden_size, action_size, 1)) + torch.nn.Conv1d(hidden_size + (0 if decoder_depth else agent_state_size), + action_size, + 1)) self.softmax = softmax self.register_buffer("base_zero", torch.zeros(1)) diff --git a/src/train.py b/src/train.py index b756318..1b95f0d 100644 --- a/src/train.py +++ b/src/train.py @@ -44,7 +44,7 @@ # Environment parameters parser.add_argument("--tree-depth", type=int, default=2, help="Depth of the observation tree") -parser.add_argument("--model-depth", type=int, default=3, help="Depth of the observation tree") +parser.add_argument("--model-depth", type=int, default=5, help="Depth of the observation tree") parser.add_argument("--hidden-factor", type=int, default=48, help="Depth of the observation tree") parser.add_argument("--kernel-size", type=int, default=1, help="Depth of the observation tree") parser.add_argument("--squeeze-heads", type=int, default=4, help="Depth of the observation tree") @@ -185,7 +185,7 @@ def as_tensor(array_list): # env_renderer.reset() episode_start = time.time() score, collision = 0, False - agent_count = len(obs[0]) + agent_count = len(environments[0].agents) if model_type == 1: agent_obs = as_tensor(obs) agent_obs_buffer = agent_obs.clone() @@ -235,9 +235,7 @@ def as_tensor(array_list): collision, flags.step_reward, flags.collision_reward) - for idx, act in enumerate(action_dict): - for key, value in act.items(): - agent_action_buffer[idx][key] = value + agent_action_buffer = [[act[i] for i in range(agent_count)] for act in action_dict] if all_done: break From 6ee0008d53d92ad0ab6dff11d769b6e312f54197 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 25 Jul 2020 07:09:06 +0200 Subject: [PATCH 63/75] feat(interface): move depth-1 --- src/model.py | 42 +++++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/src/model.py b/src/model.py index 159b968..bfd85ad 100644 --- a/src/model.py +++ b/src/model.py @@ -127,9 +127,25 @@ def reset_cache(self): pass +class DecoderBlock(torch.nn.Module): + def __init__(self, features, message_box=None): + super(DecoderBlock, self).__init__() + self.norm = torch.nn.InstanceNorm1d(features, affine=True) + self.conv = torch.nn.Conv1d(features, features, 1, bias=False) + self.message_box = int(features ** 0.5) if message_box is None else message_box + + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + out = self.norm(fn_input) + out = mish(out) + out = self.conv(out) + if self.message_box > 0: + out[:, :self.message_box] = out[:, :self.message_box].mean(-1, keepdim=True).expand(-1, -1, fn_input.size(-1)) + return out + fn_input + + class GlobalStateNetwork(torch.nn.Module): def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, - decoder_depth=1 , cat=True, debug=True, softmax=False): + decoder_depth=1, cat=True, debug=True, softmax=False): super(GlobalStateNetwork, self).__init__() _ = state_size _ = kernel_size @@ -147,17 +163,10 @@ def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size double=False, agent_dim=False) for i in range(depth)]) - self.decoder = torch.nn.Sequential(*[layer - for i in range(decoder_depth) - for layer in (torch.nn.Conv1d(hidden_size + (0 if i else agent_state_size), - hidden_size, - 1, - bias=False), - torch.nn.InstanceNorm1d(hidden_size, affine=True), - Mish())], - torch.nn.Conv1d(hidden_size + (0 if decoder_depth else agent_state_size), - action_size, - 1)) + self.decoder = torch.nn.Sequential(torch.nn.Conv1d(hidden_size + agent_state_size, hidden_size, 1, bias=False), + *[DecoderBlock(hidden_size) for i in range(decoder_depth)], + torch.nn.InstanceNorm1d(hidden_size, affine=True), + torch.nn.Conv1d(hidden_size, action_size, 1)) self.softmax = softmax self.register_buffer("base_zero", torch.zeros(1)) @@ -169,9 +178,12 @@ def reset_cache(self): def forward(self, state, rail) -> torch.Tensor: if torch.equal(self.encoding_cache, self.base_zero): - self.encoding_cache = self.net(rail) - self.encoding_cache = self.encoding_cache.mean((2, 3), keepdim=True).squeeze(-1) - inp = torch.cat([self.encoding_cache.clone().expand(-1, -1, state.size(-1)), state], 1) + inp = self.net(rail) + inp = inp.mean((2, 3), keepdim=True).squeeze(-1) + self.encoding_cache = inp + else: + inp = self.encoding_cache + inp = torch.cat([inp.expand(-1, -1, state.size(-1)), state], 1) out = self.decoder(inp) return torch.nn.functional.softmax(out, 1) if self.softmax else out From 50503ea8297f47a43b5209cf13022d18c48ef86c Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 25 Jul 2020 10:31:52 +0200 Subject: [PATCH 64/75] perf(cython): add header instructions --- src/observation_utils.pyx | 32 +++++++++++++++++++------------- src/rail_env.pyx | 19 ++++++++++++------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/src/observation_utils.pyx b/src/observation_utils.pyx index 61d10db..e26249e 100644 --- a/src/observation_utils.pyx +++ b/src/observation_utils.pyx @@ -1,3 +1,10 @@ +#!python +#cython: boundscheck=False +#cython: initializedcheck=False +#cython: nonecheck=False +#cython: wraparound=False +#cython: cdivision=True + from collections import defaultdict cimport numpy as cnp @@ -26,7 +33,7 @@ cpdef str get_action(int orientation, int direction): cpdef int get_direction(int orientation, int action): if action == 1: - return (orientation + 4 - 1) % 4 + return (orientation - 1) % 4 elif action == 3: return (orientation + 1) % 4 else: @@ -86,7 +93,7 @@ class GlobalObsForRailEnv(ObservationBuilder): range(self.env.height)]], dtype=self.data_type) if self.size > 0: - self._custom_rail_obs[0, self.size:-self.size, self.size:-self.size] = out + self._custom_rail_obs[0, self.size:self.env.height - self.size, self.size:self.env.width - self.size] = out else: self._custom_rail_obs = out @@ -173,8 +180,6 @@ class LocalObsForRailEnv(GlobalObsForRailEnv): # second to fourth channel only if in the grid if other_agent.position is not None: - pos = (agent_id,) + other_agent.position - # second channel only for other agents if i != agent_id: obs_agents_state[agent_id, :, :, 1] = other_agent.direction obs_agents_state[agent_id, :, :, 2] = other_agent.malfunction_data['malfunction'] @@ -197,7 +202,7 @@ class GlobalStateObs(GlobalObsForRailEnv): def get_many(self, list trash): cdef int agent_count = len(self.env.agents) cdef cnp.ndarray obs_agents_state = np.zeros((13, agent_count), dtype=np.float32) - cdef int i, agent_id + cdef int agent_id cdef tuple agent_virtual_position for agent_id, agent in enumerate(self.env.agents): @@ -437,10 +442,11 @@ class TreeObservation(ObservationBuilder): cdef list path = list() + # Skip ahead until we get to a major node, logging any agents on the tracks along the way while True: path = self.edge_paths.get((node.position, direction), []) - orientation = path[-1][-1] if path else direction + orientation = path[len(path) - 1][len(path[len(path) - 1]) - 1] if path else direction dist = total_distance + edge_length key = (*node.position, direction) next_key = (*next_node.position, orientation) @@ -556,28 +562,28 @@ cpdef create_tree_features(node, int max_depth, list data): if node == negative_infinity or node == positive_infinity or node is None: data.append(ZERO_NODE.expand((4 ** (max_depth - current_depth + 1) - 1) // 3, -1)) else: - data.append(torch.FloatTensor(node[:-2]).view(1, 11)) + data.append(torch.FloatTensor(node[:11]).unsqueeze(0)) if node.childs: for direction in ACTIONS: nodes.append((node.childs[direction], current_depth + 1)) # Normalize a tree observation cpdef normalize_observation(tuple observations, int max_depth, shared_tensor, int starting_index): - cdef list data = [] + cdef list data = [[[] for _ in range(len(observations[0]))] for _ in range(len(observations))] cdef int i = 0 + cdef int sub = 0 for i, tree in enumerate(observations, 1): if tree is None: break - data.append([]) if isinstance(tree, dict): tree = tree.values() - for t in tree: - data[-1].append([]) + for sub, t in enumerate(tree): if isinstance(t, dict): for d in t.values(): - create_tree_features(d, max_depth, data[-1][-1]) + create_tree_features(d, max_depth, data[i][sub]) else: - create_tree_features(t, max_depth, data[-1][-1]) + create_tree_features(t, max_depth, data[i][sub]) + shared_tensor[starting_index:starting_index + i] = torch.stack([torch.stack([torch.cat(dat, 0) for dat in tree if dat != []], -1) diff --git a/src/rail_env.pyx b/src/rail_env.pyx index e4dcec1..abaeac4 100644 --- a/src/rail_env.pyx +++ b/src/rail_env.pyx @@ -1,9 +1,14 @@ +#!python +#cython: boundscheck=False +#cython: initializedcheck=False +#cython: nonecheck=False +#cython: wraparound=False """ Definition of the RailEnv environment. """ import random # TODO: _ this is a global method --> utils or remove later -from typing import List, NamedTuple, Optional, Dict +from typing import List, NamedTuple, Dict import msgpack_numpy as m import numpy as np @@ -12,7 +17,6 @@ from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import IntVector2D -from flatland.core.transition_map import GridTransitionMap # Need to use circular imports for persistence. from flatland.envs import malfunction_generators as mal_gen from flatland.envs import persistence @@ -174,7 +178,7 @@ class RailEnv(Environment): schedule_generator = sched_gen.random_schedule_generator() self.schedule_generator = schedule_generator - self.rail: Optional[GridTransitionMap] = None + self.rail = None self.width = width self.height = height @@ -185,7 +189,7 @@ class RailEnv(Environment): self.obs_builder = obs_builder_object self.obs_builder.set_env(self) - self._max_episode_steps: Optional[int] = None + self._max_episode_steps = 0 self._elapsed_steps = 0 self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False) @@ -195,7 +199,7 @@ class RailEnv(Environment): self.dev_obs_dict = {} self.dev_pred_dict = {} - self.agents: List[EnvAgent] = [] + self.agents = [] self.number_of_agents = number_of_agents self.num_resets = 0 self.distance_map = DistanceMap(self.agents, self.height, self.width) @@ -209,7 +213,7 @@ class RailEnv(Environment): self.valid_positions = None # global numpy array of agents position, True means that there is an agent at that cell - self.agent_positions: np.ndarray = np.full((height, width), False) + self.agent_positions = np.full((height, width), False) # save episode timesteps ie agent positions, orientations. (not yet actions / observations) self.record_steps = record_steps # whether to save timesteps @@ -456,6 +460,7 @@ class RailEnv(Environment): return self._get_observations(), self.rewards_dict, self.dones, info_dict + def _step_agent(self, int i_agent, int action): """ Performs a step and step, start and stop penalty on a single agent in the following sub steps: @@ -536,7 +541,7 @@ class RailEnv(Environment): else: # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, # try to keep moving forward! - if (action == MOVE_LEFT or action == MOVE_RIGHT): + if action in (MOVE_LEFT, MOVE_RIGHT): _, new_cell_valid, new_direction, new_position, transition_valid = \ self._check_action_on_agent(MOVE_FORWARD, agent) From 7a69bbf6f890ea742f851e8ee5fd42308331332a Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 25 Jul 2020 10:32:24 +0200 Subject: [PATCH 65/75] perf(model): remove unused variable --- src/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.py b/src/model.py index bfd85ad..93db278 100644 --- a/src/model.py +++ b/src/model.py @@ -164,7 +164,7 @@ def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size agent_dim=False) for i in range(depth)]) self.decoder = torch.nn.Sequential(torch.nn.Conv1d(hidden_size + agent_state_size, hidden_size, 1, bias=False), - *[DecoderBlock(hidden_size) for i in range(decoder_depth)], + *[DecoderBlock(hidden_size, 0) for _ in range(decoder_depth)], torch.nn.InstanceNorm1d(hidden_size, affine=True), torch.nn.Conv1d(hidden_size, action_size, 1)) self.softmax = softmax From e0da5f2b89953589235e4b4e43dd81a89f10ee58 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 25 Jul 2020 10:32:44 +0200 Subject: [PATCH 66/75] perf(agent): use np.where instead of list-list-list-comprehension --- src/agent.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/agent.py b/src/agent.py index 8dea564..9bef766 100644 --- a/src/agent.py +++ b/src/agent.py @@ -105,10 +105,10 @@ def step(self, state, action, agent_done, collision, step_reward=0, collision_re if len(self.stack) >= UPDATE_EVERY: action = torch.tensor(self.stack[1]).flatten(0, 1).to(device) - reward = torch.tensor([[[1 if ad - else (collision_reward if c else step_reward) for ad, c in zip(ad_batch, c_batch)] - for ad_batch, c_batch in zip(ad_step, c_step)] - for ad_step, c_step in zip(self.stack[2], self.stack[3])]).flatten(0, 1).to(device) + agent_done = np.array(self.stack[2]) + collision = np.array(self.stack[3]) + reward = np.where(agent_done, 1, np.where(collision, collision_reward, step_reward)) + reward = torch.tensor(reward, device=device, dtype=torch.float).flatten(0, 1) state = self.stack[0] if isinstance(state[0], tuple): state = zip(*state) From dfe0a183146cec090551e35df3a91e12d90e0edc Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 25 Jul 2020 11:40:51 +0200 Subject: [PATCH 67/75] perf(agent): remove dead constants --- src/agent.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/agent.py b/src/agent.py index 9bef766..50396d2 100644 --- a/src/agent.py +++ b/src/agent.py @@ -11,14 +11,10 @@ from model import QNetwork, ConvNetwork, init, GlobalStateNetwork import os -BUFFER_SIZE = 500_000 BATCH_SIZE = 256 -GAMMA = 0.998 -TAU = 1e-3 CLIP_FACTOR = 0.2 LR = 1e-4 UPDATE_EVERY = 1 -DOUBLE_DQN = False CUDA = True device = torch.device("cuda:0" if CUDA and torch.cuda.is_available() else "cpu") From 7e3a92b644ed7e6987394e1437610c6966ed5a61 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 25 Jul 2020 11:41:13 +0200 Subject: [PATCH 68/75] perf(model): f(rail) at every step --- src/model.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/model.py b/src/model.py index 93db278..77f62af 100644 --- a/src/model.py +++ b/src/model.py @@ -128,18 +128,23 @@ def reset_cache(self): class DecoderBlock(torch.nn.Module): - def __init__(self, features, message_box=None): + def __init__(self, features, message_box=None, init_norm=True): super(DecoderBlock, self).__init__() - self.norm = torch.nn.InstanceNorm1d(features, affine=True) + self.init_norm = init_norm + self.norm = torch.nn.InstanceNorm1d(features, affine=True) if init_norm else nothing self.conv = torch.nn.Conv1d(features, features, 1, bias=False) self.message_box = int(features ** 0.5) if message_box is None else message_box def forward(self, fn_input: torch.Tensor) -> torch.Tensor: - out = self.norm(fn_input) - out = mish(out) + if self.init_norm: + out = self.norm(fn_input) + out = mish(out) + else: + out = fn_input out = self.conv(out) if self.message_box > 0: - out[:, :self.message_box] = out[:, :self.message_box].mean(-1, keepdim=True).expand(-1, -1, fn_input.size(-1)) + out[:, :self.message_box] = out[:, :self.message_box].mean(-1, keepdim=True).expand(-1, -1, + fn_input.size(-1)) return out + fn_input @@ -156,15 +161,15 @@ def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size agent_state_size = 2 * 13 self.net = torch.nn.Sequential(*[BasicBlock(global_state_size if not i else hidden_size, - hidden_size, + hidden_size - agent_state_size * (i == depth - 1), 2, init_norm=bool(i), message_box=0, double=False, agent_dim=False) for i in range(depth)]) - self.decoder = torch.nn.Sequential(torch.nn.Conv1d(hidden_size + agent_state_size, hidden_size, 1, bias=False), - *[DecoderBlock(hidden_size, 0) for _ in range(decoder_depth)], + self.mid_norm = torch.nn.InstanceNorm1d(hidden_size - agent_state_size, affine=True) + self.decoder = torch.nn.Sequential(*[DecoderBlock(hidden_size, 0, bool(i)) for i in range(decoder_depth)], torch.nn.InstanceNorm1d(hidden_size, affine=True), torch.nn.Conv1d(hidden_size, action_size, 1)) self.softmax = softmax @@ -177,12 +182,11 @@ def reset_cache(self): self.encoding_cache = self.base_zero def forward(self, state, rail) -> torch.Tensor: - if torch.equal(self.encoding_cache, self.base_zero): - inp = self.net(rail) - inp = inp.mean((2, 3), keepdim=True).squeeze(-1) - self.encoding_cache = inp - else: - inp = self.encoding_cache + inp = self.net(rail) + inp = inp.mean((2, 3), keepdim=True).squeeze(-1) + inp = mish(inp) + inp = self.mid_norm(inp) + self.encoding_cache = inp inp = torch.cat([inp.expand(-1, -1, state.size(-1)), state], 1) out = self.decoder(inp) return torch.nn.functional.softmax(out, 1) if self.softmax else out From 5962a31fa0b423913e1b8c51a5f17f0f112eaa77 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 25 Jul 2020 11:41:36 +0200 Subject: [PATCH 69/75] feat(railway_utils): adapt new file schema --- src/railway_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/railway_utils.py b/src/railway_utils.py index 56bc8fe..1cf0216 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -94,13 +94,11 @@ def __call__(self, rail, _, hints, *args, **kwargs): def load_precomputed_railways(project_root, start_index, big=False): prefix = os.path.join(project_root, 'railroads') if big: - suffix = f'_35_6.pkl' + suffix = f'_35_6.pkl' # base=1.1 else: - suffix = f'_50.pkl' + suffix = f'_35_4.pkl' # base=1.04 rail = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) sched = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) - if not big: - sched, rail = rail, sched print(f"Working on {len(rail)} tracks") return rail, sched From 110fbc8beb6e1ba3eabfa2bb5dfde0bf6791d642 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 25 Jul 2020 23:28:04 +0200 Subject: [PATCH 70/75] perf(model): add attention --- src/model.py | 53 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/src/model.py b/src/model.py index 77f62af..273dbd3 100644 --- a/src/model.py +++ b/src/model.py @@ -128,12 +128,13 @@ def reset_cache(self): class DecoderBlock(torch.nn.Module): - def __init__(self, features, message_box=None, init_norm=True): + def __init__(self, in_features, out_features, message_box=None, init_norm=True): super(DecoderBlock, self).__init__() self.init_norm = init_norm - self.norm = torch.nn.InstanceNorm1d(features, affine=True) if init_norm else nothing - self.conv = torch.nn.Conv1d(features, features, 1, bias=False) - self.message_box = int(features ** 0.5) if message_box is None else message_box + self.residual = in_features == out_features + self.norm = torch.nn.InstanceNorm1d(in_features, affine=True) if init_norm else nothing + self.conv = torch.nn.Conv1d(in_features, out_features, 1, bias=False) + self.message_box = int(out_features ** 0.5) if message_box is None else message_box def forward(self, fn_input: torch.Tensor) -> torch.Tensor: if self.init_norm: @@ -145,12 +146,23 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: if self.message_box > 0: out[:, :self.message_box] = out[:, :self.message_box].mean(-1, keepdim=True).expand(-1, -1, fn_input.size(-1)) - return out + fn_input + if self.residual: + return out + fn_input + return out + + +@torch.jit.script +def attention(tensor: torch.Tensor): + query, key, value = tensor.chunk(3, 1) + query = query.transpose(1, 2) # B, F, S -> B, S, F + key = torch.bmm(query, key).softmax(1) + value = torch.bmm(value, key) + return value class GlobalStateNetwork(torch.nn.Module): def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, - decoder_depth=1, cat=True, debug=True, softmax=False): + decoder_depth=1, cat=True, debug=True, softmax=False, memory_size=4): super(GlobalStateNetwork, self).__init__() _ = state_size _ = kernel_size @@ -160,35 +172,38 @@ def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size global_state_size = 2 * 16 agent_state_size = 2 * 13 + self.memory_size = memory_size + self.net = torch.nn.Sequential(*[BasicBlock(global_state_size if not i else hidden_size, - hidden_size - agent_state_size * (i == depth - 1), + hidden_size + (i == depth - 1) * hidden_size * 2, 2, init_norm=bool(i), message_box=0, double=False, agent_dim=False) for i in range(depth)]) - self.mid_norm = torch.nn.InstanceNorm1d(hidden_size - agent_state_size, affine=True) - self.decoder = torch.nn.Sequential(*[DecoderBlock(hidden_size, 0, bool(i)) for i in range(decoder_depth)], - torch.nn.InstanceNorm1d(hidden_size, affine=True), - torch.nn.Conv1d(hidden_size, action_size, 1)) + self.decoder_input = DecoderBlock(agent_state_size, 3 * hidden_size, 0, False) + self.decoder = torch.nn.ModuleList([DecoderBlock(hidden_size, 3 * hidden_size, 0, True) + for _ in range(1, decoder_depth)]) + self.end_norm = torch.nn.InstanceNorm1d(hidden_size, affine=True) + self.end_conv = torch.nn.Conv1d(hidden_size, action_size, 1) self.softmax = softmax - self.register_buffer("base_zero", torch.zeros(1)) - self.encoding_cache = self.base_zero + self.memory_tensor = torch.nn.Parameter(torch.randn(1, 3 * hidden_size, memory_size)) @torch.jit.export def reset_cache(self): - self.encoding_cache = self.base_zero + pass def forward(self, state, rail) -> torch.Tensor: inp = self.net(rail) inp = inp.mean((2, 3), keepdim=True).squeeze(-1) - inp = mish(inp) - inp = self.mid_norm(inp) - self.encoding_cache = inp - inp = torch.cat([inp.expand(-1, -1, state.size(-1)), state], 1) - out = self.decoder(inp) + state = torch.cat([self.decoder_input(state) + inp, self.memory_tensor.expand(inp.size(0), -1, -1)], 2) + state = attention(state) + for block in self.decoder: + state = attention(block(state)) + state = state[:, :, :-self.memory_size] + out = self.end_conv(self.end_norm(state)) return torch.nn.functional.softmax(out, 1) if self.softmax else out From eb1b7b428e054e1693d341c2d15e6516ee27ab3d Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 25 Jul 2020 23:30:15 +0200 Subject: [PATCH 71/75] fix(agent): step every n steps, not every n data types --- src/agent.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/agent.py b/src/agent.py index 50396d2..442e016 100644 --- a/src/agent.py +++ b/src/agent.py @@ -11,10 +11,10 @@ from model import QNetwork, ConvNetwork, init, GlobalStateNetwork import os -BATCH_SIZE = 256 +BATCH_SIZE = 4 CLIP_FACTOR = 0.2 -LR = 1e-4 -UPDATE_EVERY = 1 +LR = 1e-5 +UPDATE_EVERY = 32 CUDA = True device = torch.device("cuda:0" if CUDA and torch.cuda.is_available() else "cpu") @@ -99,7 +99,7 @@ def step(self, state, action, agent_done, collision, step_reward=0, collision_re or not k.startswith('_')] for a in agent_done]) self.stack[3].append(collision) - if len(self.stack) >= UPDATE_EVERY: + if len(self.stack[0]) >= UPDATE_EVERY: action = torch.tensor(self.stack[1]).flatten(0, 1).to(device) agent_done = np.array(self.stack[2]) collision = np.array(self.stack[3]) From 32dc07cce4af44893ee01fada8d5eea1aedbe14c Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 26 Jul 2020 20:58:12 +0200 Subject: [PATCH 72/75] style(train): improve maintainability --- src/train.py | 108 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 63 insertions(+), 45 deletions(-) diff --git a/src/train.py b/src/train.py index 1b95f0d..0f2b0fb 100644 --- a/src/train.py +++ b/src/train.py @@ -3,6 +3,7 @@ import time from itertools import zip_longest from pathlib import Path + import numpy as np import torch from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters @@ -19,14 +20,14 @@ try: from .rail_env import RailEnv - from .agent import Agent as DQN_Agent, device, BATCH_SIZE + from .agent import Agent as DQN_Agent, device, BATCH_SIZE, UPDATE_EVERY from .normalize_output_data import wrap from .observation_utils import normalize_observation, TreeObservation, GlobalObsForRailEnv, LocalObsForRailEnv, \ GlobalStateObs from .railway_utils import load_precomputed_railways, create_random_railways except: from rail_env import RailEnv - from agent import Agent as DQN_Agent, device, BATCH_SIZE + from agent import Agent as DQN_Agent, device, BATCH_SIZE, UPDATE_EVERY from normalize_output_data import wrap from observation_utils import normalize_observation, TreeObservation, GlobalObsForRailEnv, LocalObsForRailEnv, \ GlobalStateObs @@ -130,7 +131,7 @@ def is_collision(a, i): return any(own_agent.position == agent.position for agent_id, agent in enumerate(environments[i].agents) if agent_id != a) -else: #model_type == 0 +else: # model_type == 0 def is_collision(a, i): if obs[i][a] is None: return False is_junction = not isinstance(obs[i][a].childs['L'], float) or not isinstance(obs[i][a].childs['R'], float) @@ -164,9 +165,40 @@ def normalize(observation, target_tensor): normalize_observation(observation, flags.tree_depth, target_tensor, 0) wrap(target_tensor) + def as_tensor(array_list): return torch.as_tensor(np.stack(array_list, 0), dtype=torch.float, device=device) + +def make_tensor(current, old): + if model_type == 1: + tensor = (torch.cat([old[0], current[0]], -1),) + elif model_type == 0: + tensor = (torch.cat([old[0].flatten(1, 2), current[0].flatten(1, 2)], 1),) + else: # model_type == 2 + tensor = (torch.cat((old[0], current[0]), 1), torch.cat((old[1], current[1]), -1)) + tensor[1].transpose_(1, -1) + tensor = tuple(t.to(device) for t in tensor) + return tensor + + +def clone(tensor_tuple): + return tuple(t.clone().detach().requires_grad_(t.requires_grad) for t in tensor_tuple) + + +def get_observation_tensor(observation, prev_tensor=None): + buffer = None if prev_tensor is None else clone(prev_tensor) + if model_type == 1: + obs_tensor = as_tensor(observation) + elif model_type == 0: + obs_tensor = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) + normalize(observation, obs_tensor) + else: # model_type == 2 + rail, obs = zip(*observation) + obs_tensor = as_tensor(rail), as_tensor(obs) + return obs_tensor, buffer + + episode = 0 POOL = multiprocessing.Pool() # env_renderer = None @@ -177,45 +209,30 @@ def as_tensor(array_list): # Main training loop episode = start +running_stats = {'score': 0, 'steps': 0, 'collisions': 0, 'done': 0, 'finished': 0} +batch_start = time.time() while True: episode += 1 agent.reset() obs, info = zip(*[env.reset() for env in environments]) # env_renderer = RenderTool(environments[0], gl="PILSVG", screen_width=1000, screen_height=1000, agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX) # env_renderer.reset() - episode_start = time.time() score, collision = 0, False agent_count = len(environments[0].agents) - if model_type == 1: - agent_obs = as_tensor(obs) - agent_obs_buffer = agent_obs.clone() - elif model_type == 0: - agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) - normalize(obs, agent_obs) - agent_obs_buffer = agent_obs.clone() - else: # model_type == 2 - rail, obs = zip(*obs) - agent_obs = as_tensor(rail), as_tensor(obs) - agent_obs_buffer = agent_obs[0].clone(), agent_obs[1].clone() agent_action_buffer = [[2] * agent_count for _ in range(BATCH_SIZE)] - + agent_obs, _ = get_observation_tensor(obs) + agent_obs_buffer = clone(agent_obs) + input_tensor = make_tensor(agent_obs, agent_obs_buffer) # Run an episode city_count = (env.width * env.height) // 300 max_steps = int(8 * (env.width + env.height + agent_count / city_count)) - 10 # -10 = have some distance to the "real" max steps done = [[False]] _done = [[False]] + step = 0 for step in range(max_steps): done = _done - if model_type == 1: - input_tensor = torch.cat([agent_obs_buffer, agent_obs], -1) - elif model_type == 0: - input_tensor = torch.cat([agent_obs_buffer.flatten(1, 2), agent_obs.flatten(1, 2)], 1) - else: # model_type == 2 - input_tensor = (torch.cat((agent_obs_buffer[0], agent_obs[0]), 1), - torch.cat((agent_obs_buffer[1], agent_obs[1]), -1)) - input_tensor[1].transpose_(1, -1) ret_action = agent.multi_act(input_tensor) action_dict = [dict(enumerate(act_list)) for act_list in ret_action] @@ -228,29 +245,22 @@ def as_tensor(array_list): all_done = (step == (max_steps - 1)) or all(d['__all__'] for d in _done) collision = [[is_collision(a, i) for a in range(agent_count)] for i in range(BATCH_SIZE)] # Update replay buffer and train agent + agent_obs, agent_obs_buffer = get_observation_tensor(obs, agent_obs) + next_input = make_tensor(agent_obs, agent_obs_buffer) + if flags.train: agent.step(input_tensor, agent_action_buffer, _done, collision, + next_input, flags.step_reward, flags.collision_reward) agent_action_buffer = [[act[i] for i in range(agent_count)] for act in action_dict] if all_done: break - - if model_type == 1: - agent_obs_buffer = agent_obs.clone() - agent_obs = as_tensor(obs) - elif model_type == 0: - agent_obs_buffer = agent_obs.clone() - agent_obs = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count)) - normalize(obs, agent_obs) - else: # model_type == 2 - rail, obs = zip(*obs) - agent_obs_buffer = agent_obs[0].clone(), agent_obs[1].clone() - agent_obs = as_tensor(rail), as_tensor(obs) + input_tensor = next_input # Render # if flags.render_interval and episode % flags.render_interval == 0: @@ -260,15 +270,23 @@ def as_tensor(array_list): # render() # print("Collisions detected by agent(s)", ', '.join(str(a) for a in obs if is_collision(a))) # break + running_stats['score'] += score / max_steps + running_stats['steps'] += step + running_stats['collisions'] += sum(i for c in collision for i in c) / agent_count + running_stats['done'] += sum(d[i] for d in done for i in range(agent_count)) / agent_count + running_stats['finished'] += sum(d["__all__"] for d in done) + + if episode % UPDATE_EVERY == 0: + running_stats = {k: v / UPDATE_EVERY for k, v in running_stats.items()} + print(f'\rBatch{episode:>3} - Episode{BATCH_SIZE * episode:>5} - Agents:{agent_count:>3}' + f' | Score: {running_stats["score"]:.4f}' + f' | Steps: {running_stats["steps"]:4.0f}' + f' | Collisions: {100 * running_stats["collisions"] / BATCH_SIZE:6.2f}%' + f' | Done: {100 * running_stats["done"] / BATCH_SIZE:6.2f}%' + f' | Finished: {100 * running_stats["finished"] / BATCH_SIZE:6.2f}%' + f' | Took: {time.time() - batch_start:5.0f}s') + running_stats = {k: 0 for k in running_stats.keys()} + batch_start = time.time() - print(f'\rBatch{episode:>3} - Episode{BATCH_SIZE * episode:>5} - Agents:{agent_count:>3}' - f' | Score: {score / max_steps:.4f}' - f' | Steps: {step:4.0f}' - f' | Collisions: {100 * sum(i for c in collision for i in c) / (BATCH_SIZE * agent_count):6.2f}%' - f' | Done: {100 * sum(d[i] for d in done for i in range(agent_count)) / (BATCH_SIZE * agent_count):6.2f}%' - f' | Finished: {100 * sum(d["__all__"] for d in done) / BATCH_SIZE:6.2f}%' - f' | Took: {time.time() - episode_start:5.0f}s', end='') - - print("") # if flags.train: # agent.save(project_root / 'checkpoints', episode) From 899e03fb69493aa80f904c65a67ae7be28f692a3 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 26 Jul 2020 20:58:22 +0200 Subject: [PATCH 73/75] feat(agent): readd dqn --- src/agent.py | 95 +++++++++++++++++++++++++++++----------------------- 1 file changed, 53 insertions(+), 42 deletions(-) diff --git a/src/agent.py b/src/agent.py index 442e016..6a92823 100644 --- a/src/agent.py +++ b/src/agent.py @@ -11,11 +11,14 @@ from model import QNetwork, ConvNetwork, init, GlobalStateNetwork import os -BATCH_SIZE = 4 +BATCH_SIZE = 256 CLIP_FACTOR = 0.2 -LR = 1e-5 -UPDATE_EVERY = 32 +LR = 1e-4 +UPDATE_EVERY = 1 CUDA = True +MINI_BACKWARD = False +DQN = True +DQN_PARAMS = {'tau': 1e-3, 'gamma': 0.998} device = torch.device("cuda:0" if CUDA and torch.cuda.is_available() else "cpu") @@ -70,71 +73,79 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s self.optimizer = Optimizer(self.policy.parameters(), lr=LR, weight_decay=1e-2) # Replay memory - self.stack = [[] for _ in range(4)] + self.stack = [[] for _ in range(6)] self.t_step = 0 + self.idx = 1 def reset(self): self.policy.reset_cache() self.old_policy.reset_cache() def multi_act(self, state): - if isinstance(state, tuple): - state = tuple(s.to(device) for s in state) - elif isinstance(state, torch.Tensor): - state = (state.to(device),) self.policy.eval() with torch.no_grad(): action_values = self.policy(*state) - - # Epsilon-greedy action selection return action_values.argmax(1).detach().cpu().numpy() - # Record the results of the agent's action and update the model - - def step(self, state, action, agent_done, collision, step_reward=0, collision_reward=-2): + def step(self, state, action, agent_done, collision, next_state, step_reward=0, collision_reward=-2): + agent_count = len(agent_done[0])-1 self.stack[0].append(state) self.stack[1].append(action) - self.stack[2].append([[v for k, v in a.items() - if not hasattr(k, 'startswith') - or not k.startswith('_')] for a in agent_done]) + self.stack[2].append([[done[idx] for idx in range(agent_count)] for done in agent_done]) self.stack[3].append(collision) + self.stack[4].append(next_state) - if len(self.stack[0]) >= UPDATE_EVERY: + if MINI_BACKWARD or len(self.stack[0]) >= UPDATE_EVERY: action = torch.tensor(self.stack[1]).flatten(0, 1).to(device) agent_done = np.array(self.stack[2]) collision = np.array(self.stack[3]) reward = np.where(agent_done, 1, np.where(collision, collision_reward, step_reward)) reward = torch.tensor(reward, device=device, dtype=torch.float).flatten(0, 1) - state = self.stack[0] - if isinstance(state[0], tuple): - state = zip(*state) - state = tuple(torch.cat(st, 0).to(device) for st in state) - elif isinstance(state[0], torch.Tensor): - state = (torch.cat(state, 0).to(device),) - self.stack = [[] for _ in range(4)] - self.learn(state, action, reward) - - def learn(self, states, actions, rewards): + state = tuple(torch.cat(st, 0) for st in zip(*self.stack[0])) + next_state = tuple(torch.cat(st, 0) for st in zip(*self.stack[4])) + agent_done = torch.as_tensor(agent_done, device=device, dtype=torch.int8) + self.stack = [[] for _ in range(6)] + self.learn(state, action, reward, next_state, agent_done) + + def learn(self, states, actions, rewards, next_states, done): + if MINI_BACKWARD: + self.idx = (self.idx + 1) % UPDATE_EVERY + self.policy.train() actions.unsqueeze_(1) - with torch.no_grad(): - states_clone = tuple(st.clone() for st in states) - for st in states: - st.requires_grad_(False) - old_responsible_outputs = self.old_policy(*states_clone).gather(1, actions) - old_responsible_outputs.detach_() - responsible_outputs = self.policy(*states).gather(1, actions) - ratio = responsible_outputs / (old_responsible_outputs + 1e-5) - ratio.squeeze_(1) - clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) - loss = -torch.min(ratio * rewards, clamped_ratio * rewards).sum(-1).max() - self.old_policy.load_state_dict(self.policy.state_dict()) - self.optimizer.zero_grad() + if DQN: + expected = self.policy(*states).gather(1, actions) + best_action = self.policy(*next_states).argmax(1) + targets_next = self.old_policy(*next_states).gather(1, best_action.unsqueeze(1)) + targets = rewards + DQN_PARAMS['gamma'] * targets_next * (1 - done) + loss = (expected - targets).square().max(0)[0].mean() + else: + with torch.no_grad(): + states_clone = tuple(st.clone().detach().requires_grad_(False) for st in states) + old_responsible_outputs = self.old_policy(*states_clone).gather(1, actions) + old_responsible_outputs.detach_() + responsible_outputs = self.policy(*states).gather(1, actions) + ratio = responsible_outputs / (old_responsible_outputs + 1e-5) + ratio.squeeze_(1) + clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) + loss = torch.min(ratio * rewards, clamped_ratio * rewards).sum(-1).max(0)[0].mean().neg() + + if MINI_BACKWARD: + loss = loss / UPDATE_EVERY + loss.backward() - self.optimizer.step() - # Checkpointing methods + if not MINI_BACKWARD or self.idx == 0: + if not DQN: + self.old_policy.load_state_dict(self.policy.state_dict()) + self.optimizer.step() + self.optimizer.zero_grad() + if DQN: + for target_param, local_param in zip(self.old_policy.parameters(), self.policy.parameters()): + target_param.data.copy_(DQN_PARAMS['tau'] * local_param.data + + (1.0 - DQN_PARAMS['tau']) * target_param.data) + def save(self, path, *data): torch.save(self.policy.state_dict(), path / 'dqn/model_checkpoint.local') From 4a6077fc350fcb1b6d81a001b8012ac53d30fa07 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 26 Jul 2020 23:27:33 +0200 Subject: [PATCH 74/75] fix(agent): readd replay buffer, loss functions --- src/agent.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/agent.py b/src/agent.py index 6a92823..ba14565 100644 --- a/src/agent.py +++ b/src/agent.py @@ -11,14 +11,15 @@ from model import QNetwork, ConvNetwork, init, GlobalStateNetwork import os -BATCH_SIZE = 256 +BATCH_SIZE = 64 CLIP_FACTOR = 0.2 LR = 1e-4 UPDATE_EVERY = 1 CUDA = True MINI_BACKWARD = False -DQN = True +DQN = False DQN_PARAMS = {'tau': 1e-3, 'gamma': 0.998} +EPOCHS = 16 device = torch.device("cuda:0" if CUDA and torch.cuda.is_available() else "cpu") @@ -76,6 +77,7 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s self.stack = [[] for _ in range(6)] self.t_step = 0 self.idx = 1 + self.tensor_stack = [] def reset(self): self.policy.reset_cache() @@ -88,7 +90,7 @@ def multi_act(self, state): return action_values.argmax(1).detach().cpu().numpy() def step(self, state, action, agent_done, collision, next_state, step_reward=0, collision_reward=-2): - agent_count = len(agent_done[0])-1 + agent_count = len(agent_done[0]) - 1 self.stack[0].append(state) self.stack[1].append(action) self.stack[2].append([[done[idx] for idx in range(agent_count)] for done in agent_done]) @@ -96,40 +98,43 @@ def step(self, state, action, agent_done, collision, next_state, step_reward=0, self.stack[4].append(next_state) if MINI_BACKWARD or len(self.stack[0]) >= UPDATE_EVERY: - action = torch.tensor(self.stack[1]).flatten(0, 1).to(device) + action = torch.tensor(self.stack[1]).flatten(0, 1).to(device).unsqueeze_(1) agent_done = np.array(self.stack[2]) collision = np.array(self.stack[3]) reward = np.where(agent_done, 1, np.where(collision, collision_reward, step_reward)) - reward = torch.tensor(reward, device=device, dtype=torch.float).flatten(0, 1) + reward = torch.tensor(reward, device=device, dtype=torch.float).flatten(0, 1).unsqueeze_(1) state = tuple(torch.cat(st, 0) for st in zip(*self.stack[0])) next_state = tuple(torch.cat(st, 0) for st in zip(*self.stack[4])) - agent_done = torch.as_tensor(agent_done, device=device, dtype=torch.int8) + agent_done = torch.as_tensor(agent_done, device=device, dtype=torch.int8).flatten(0, 1).unsqueeze_(1) self.stack = [[] for _ in range(6)] - self.learn(state, action, reward, next_state, agent_done) + self.tensor_stack.append((state, action, reward, next_state, agent_done)) + if len(self.tensor_stack) >= EPOCHS: + tensor_stack = (torch.cat(t, 0) if isinstance(t[0], torch.Tensor) + else tuple(torch.cat(sub_t, 0) for sub_t in zip(*t)) + for t in zip(*self.tensor_stack)) + del self.tensor_stack[0] + self.learn(*tensor_stack) def learn(self, states, actions, rewards, next_states, done): if MINI_BACKWARD: self.idx = (self.idx + 1) % UPDATE_EVERY self.policy.train() - actions.unsqueeze_(1) if DQN: expected = self.policy(*states).gather(1, actions) best_action = self.policy(*next_states).argmax(1) targets_next = self.old_policy(*next_states).gather(1, best_action.unsqueeze(1)) targets = rewards + DQN_PARAMS['gamma'] * targets_next * (1 - done) - loss = (expected - targets).square().max(0)[0].mean() + loss = (expected - targets).square().max(0)[0].sum() else: with torch.no_grad(): states_clone = tuple(st.clone().detach().requires_grad_(False) for st in states) - old_responsible_outputs = self.old_policy(*states_clone).gather(1, actions) - old_responsible_outputs.detach_() + old_responsible_outputs = self.old_policy(*states_clone).gather(1, actions).detach_() responsible_outputs = self.policy(*states).gather(1, actions) ratio = responsible_outputs / (old_responsible_outputs + 1e-5) - ratio.squeeze_(1) clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) - loss = torch.min(ratio * rewards, clamped_ratio * rewards).sum(-1).max(0)[0].mean().neg() + loss = torch.min(ratio * rewards, clamped_ratio * rewards).max(0)[0].sum().neg() if MINI_BACKWARD: loss = loss / UPDATE_EVERY @@ -146,7 +151,6 @@ def learn(self, states, actions, rewards, next_states, done): target_param.data.copy_(DQN_PARAMS['tau'] * local_param.data + (1.0 - DQN_PARAMS['tau']) * target_param.data) - def save(self, path, *data): torch.save(self.policy.state_dict(), path / 'dqn/model_checkpoint.local') torch.save(self.old_policy.state_dict(), path / 'dqn/model_checkpoint.target') From bb9db64898efd63cfd8c87fcfa5325ab63e80898 Mon Sep 17 00:00:00 2001 From: Luke <39779310+ClashLuke@users.noreply.github.com> Date: Mon, 27 Jul 2020 09:33:48 +0200 Subject: [PATCH 75/75] feat(agent): improve maintainability, add naf --- src/agent.py | 113 ++++++++++++++++++++++++++++++++++-------------- src/model.py | 119 +++++++++++++++++++++++++++++++++++---------------- src/train.py | 11 ++--- 3 files changed, 168 insertions(+), 75 deletions(-) diff --git a/src/agent.py b/src/agent.py index ba14565..0df4269 100644 --- a/src/agent.py +++ b/src/agent.py @@ -1,32 +1,49 @@ import math import pickle +import typing import numpy as np import torch from torch_optimizer import Yogi as Optimizer try: - from .model import QNetwork, ConvNetwork, init, GlobalStateNetwork + from .model import QNetwork, ConvNetwork, init, GlobalStateNetwork, TripleClassificationHead, NAFHead except: - from model import QNetwork, ConvNetwork, init, GlobalStateNetwork + from model import QNetwork, ConvNetwork, init, GlobalStateNetwork, TripleClassificationHead, NAFHead import os -BATCH_SIZE = 64 +BATCH_SIZE = 256 CLIP_FACTOR = 0.2 LR = 1e-4 UPDATE_EVERY = 1 CUDA = True MINI_BACKWARD = False -DQN = False -DQN_PARAMS = {'tau': 1e-3, 'gamma': 0.998} -EPOCHS = 16 +DQN_TAU = 1e-3 +EPOCHS = 1 device = torch.device("cuda:0" if CUDA and torch.cuda.is_available() else "cpu") -class Agent: +@torch.jit.script +def aggregate(loss: torch.Tensor): + maximum, _ = loss.max(0) + return maximum.sum() + + +@torch.jit.script +def mse(in_x, in_y): + return aggregate((in_x - in_y).square()) + + +@torch.jit.script +def dqn_target(rewards, targets_next, done): + return rewards + 0.998 * targets_next * (1 - done) + + +class Agent(torch.nn.Module): def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_size, squeeze_heads, decoder_depth, - model_type=0, softmax=True, debug=True): + model_type=0, softmax=True, debug=True, loss_type='PPO'): + super(Agent, self).__init__() self.action_size = action_size # Q-Network @@ -36,21 +53,25 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s network = QNetwork else: # Global State network = GlobalStateNetwork + if loss_type in ('PPO', 'DQN'): + tail = TripleClassificationHead(hidden_factor, action_size) + else: + tail = NAFHead(hidden_factor, action_size) self.policy = network(state_size, - action_size, hidden_factor, model_depth, kernel_size, squeeze_heads, decoder_depth, + tail=tail, softmax=softmax).to(device) self.old_policy = network(state_size, - action_size, hidden_factor, model_depth, kernel_size, squeeze_heads, decoder_depth, + tail=tail, softmax=softmax, debug=False).to(device) if debug: @@ -78,16 +99,57 @@ def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_s self.t_step = 0 self.idx = 1 self.tensor_stack = [] + self._policy_update = loss_type in ("PPO",) + self._soft_update = loss_type in ("DQN", "NAF") + + self._action_index = torch.zeros(1) + self._value_index = torch.zeros(1) + 1 + self._triangular_index = torch.zeros(1) + 2 + + self.loss = getattr(self, f'_{loss_type.lower()}_loss') + + def _dqn_loss(self, states, actions, next_states, rewards, done): + actions = actions.argmax(1) + expected = self.policy(self._action_index, self._action_index, *states).gather(1, actions) + best_action = self.policy(self._action_index, self._action_index, next_states).argmax(1) + targets_next = self.old_policy(self._action_index, self._action_index, + *next_states).gather(1, best_action.unsqueeze(1)) + targets = dqn_target(rewards, targets_next, done) + loss = mse(expected, targets) + return loss + + def _naf_loss(self, states, actions, next_states, rewards, done): + targets_next = self.old_policy(self._value_index, self._action_index, next_states) + state_action_values = self.policy(self._triangular_index, actions, states) + targets = dqn_target(rewards, targets_next, done) + loss = mse(state_action_values, targets) + return loss + + def _ppo_loss(self, states, actions, next_states, rewards, done): + _ = next_states + _ = done + actions = actions.argmax(1) + states_clone = [st.clone().detach().requires_grad_(False) for st in states] + old_responsible_outputs = self.old_policy(self._action_index, self._action_index, + *states_clone).gather(1, actions).detach_() + responsible_outputs = self.policy(self._action_index, self._action_index, *states).gather(1, actions) + ratio = responsible_outputs / (old_responsible_outputs + 1e-5) + clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) + loss = aggregate(torch.min(ratio * rewards, clamped_ratio * rewards)).neg() + return loss def reset(self): self.policy.reset_cache() self.old_policy.reset_cache() - def multi_act(self, state): + def multi_act(self, state, argmax_only=True) -> typing.Union[typing.Tuple[np.ndarray, np.ndarray], np.ndarray]: self.policy.eval() with torch.no_grad(): - action_values = self.policy(*state) - return action_values.argmax(1).detach().cpu().numpy() + action_values = self.policy(self._action_index, self._action_index, *state).detach() + argmax = action_values.argmax(1).cpu().numpy() + if argmax_only: + return argmax + return action_values, argmax def step(self, state, action, agent_done, collision, next_state, step_reward=0, collision_reward=-2): agent_count = len(agent_done[0]) - 1 @@ -98,7 +160,7 @@ def step(self, state, action, agent_done, collision, next_state, step_reward=0, self.stack[4].append(next_state) if MINI_BACKWARD or len(self.stack[0]) >= UPDATE_EVERY: - action = torch.tensor(self.stack[1]).flatten(0, 1).to(device).unsqueeze_(1) + action = torch.cat(self.stack[1]).to(device).unsqueeze_(1) agent_done = np.array(self.stack[2]) collision = np.array(self.stack[3]) reward = np.where(agent_done, 1, np.where(collision, collision_reward, step_reward)) @@ -121,20 +183,7 @@ def learn(self, states, actions, rewards, next_states, done): self.policy.train() - if DQN: - expected = self.policy(*states).gather(1, actions) - best_action = self.policy(*next_states).argmax(1) - targets_next = self.old_policy(*next_states).gather(1, best_action.unsqueeze(1)) - targets = rewards + DQN_PARAMS['gamma'] * targets_next * (1 - done) - loss = (expected - targets).square().max(0)[0].sum() - else: - with torch.no_grad(): - states_clone = tuple(st.clone().detach().requires_grad_(False) for st in states) - old_responsible_outputs = self.old_policy(*states_clone).gather(1, actions).detach_() - responsible_outputs = self.policy(*states).gather(1, actions) - ratio = responsible_outputs / (old_responsible_outputs + 1e-5) - clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) - loss = torch.min(ratio * rewards, clamped_ratio * rewards).max(0)[0].sum().neg() + loss = self.loss(states, actions, next_states, rewards, done) if MINI_BACKWARD: loss = loss / UPDATE_EVERY @@ -142,14 +191,14 @@ def learn(self, states, actions, rewards, next_states, done): loss.backward() if not MINI_BACKWARD or self.idx == 0: - if not DQN: + if self._policy_update: self.old_policy.load_state_dict(self.policy.state_dict()) self.optimizer.step() self.optimizer.zero_grad() - if DQN: + if self._soft_update: for target_param, local_param in zip(self.old_policy.parameters(), self.policy.parameters()): - target_param.data.copy_(DQN_PARAMS['tau'] * local_param.data + - (1.0 - DQN_PARAMS['tau']) * target_param.data) + target_param.data.copy_(DQN_TAU * local_param.data + + (1.0 - DQN_TAU) * target_param.data) def save(self, path, *data): torch.save(self.policy.state_dict(), path / 'dqn/model_checkpoint.local') diff --git a/src/model.py b/src/model.py index 273dbd3..dfb0714 100644 --- a/src/model.py +++ b/src/model.py @@ -99,10 +99,68 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: return out -class ConvNetwork(torch.nn.Module): - def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, - decoder_depth=1, cat=True, debug=True, softmax=False): - super(ConvNetwork, self).__init__() +class NAFHead(torch.nn.Module): + def __init__(self, hidden_size, action_size): + super(NAFHead, self).__init__() + self.action_size = action_size + self.action = torch.nn.Conv1d(hidden_size, action_size, 1) + self.value = torch.nn.Conv1d(hidden_size, 1, 1) + self.triangular = torch.nn.Conv1d(hidden_size, action_size ** 2, 1) + self.diagonal_mask = torch.ones((action_size, action_size)).diag().diag().unsqueeze_(0) + + def forward(self, fn_input, idx, prev_action): + if idx == 0: + return self.action(fn_input) + if idx == 1: + return self.value(fn_input) + + actions, value = self.action(fn_input), self.value(fn_input) + + batch = fn_input.size(0) + triangular = self.triangular(fn_input).view(batch, self.action_size, self.action_size, -1) + triangular = triangular.tril() + triangular.mul(self.diagonal_mask).exp() + matrix = torch.bmm(triangular, triangular.transpose(1, 2)) + + action_difference = (prev_action - actions).unsqueeze(2) + advantage = torch.bmm(torch.bmm(action_difference.transpose(1, 2), matrix), + action_difference)[:, :, 0].div(2).neg() + + return advantage + value + + +class TripleClassificationHead(torch.nn.Module): + def __init__(self, hidden_size, action_size): + super(TripleClassificationHead, self).__init__() + self.linear = torch.nn.Conv1d(hidden_size, action_size, 1) + + def forward(self, fn_input, idx, prev_action): + _ = idx + _ = prev_action + return self.linear(fn_input) + + +class TailModel(torch.nn.Module): + def __init__(self, tail): + super(TailModel, self).__init__() + self.tail = nothing if tail is None else tail + self.no_tensor = torch.zeros(1) + + def _backbone(self, fn_input: torch.Tensor, potential_fn_input: torch.Tensor) -> torch.Tensor: + raise UserWarning("Has to be implemented by child class") + + def forward(self, idx, prev_action, state, rail): + out = self._backbone(state, rail) + return self.tail(out, idx, prev_action) + + @torch.jit.export + def reset_cache(self): + pass + + +class ConvNetwork(TailModel): + def __init__(self, state_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, + decoder_depth=1, cat=True, debug=True, softmax=False, tail=None): + super(ConvNetwork, self).__init__(tail) _ = state_size state_size = 2 * 21 self.net = torch.nn.ModuleList([BasicBlock(state_size if not i else hidden_size, @@ -111,20 +169,15 @@ def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size init_norm=bool(i)) for i in range(depth)]) self.init_norm = torch.nn.InstanceNorm1d(hidden_size, affine=True) - self.linear = torch.nn.Conv1d(hidden_size, action_size, 1) - self.softmax = softmax - def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + def _backbone(self, fn_input: torch.Tensor, trash: torch.Tensor) -> torch.Tensor: + _ = trash out = fn_input for module in self.net: out = module(out) out = out.mean((2, 3)) - out = self.linear(mish(self.init_norm(out))) - return torch.nn.functional.softmax(out, 1) if self.softmax else out - - @torch.jit.export - def reset_cache(self): - pass + out = mish(self.init_norm(out)) + return out class DecoderBlock(torch.nn.Module): @@ -160,10 +213,10 @@ def attention(tensor: torch.Tensor): return value -class GlobalStateNetwork(torch.nn.Module): - def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, - decoder_depth=1, cat=True, debug=True, softmax=False, memory_size=4): - super(GlobalStateNetwork, self).__init__() +class GlobalStateNetwork(TailModel): + def __init__(self, state_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, + decoder_depth=1, cat=True, debug=True, softmax=False, memory_size=4, tail=None): + super(GlobalStateNetwork, self).__init__(tail) _ = state_size _ = kernel_size _ = squeeze_heads @@ -186,16 +239,10 @@ def __init__(self, state_size, action_size, hidden_size=15, depth=8, kernel_size self.decoder = torch.nn.ModuleList([DecoderBlock(hidden_size, 3 * hidden_size, 0, True) for _ in range(1, decoder_depth)]) self.end_norm = torch.nn.InstanceNorm1d(hidden_size, affine=True) - self.end_conv = torch.nn.Conv1d(hidden_size, action_size, 1) - self.softmax = softmax self.memory_tensor = torch.nn.Parameter(torch.randn(1, 3 * hidden_size, memory_size)) - @torch.jit.export - def reset_cache(self): - pass - - def forward(self, state, rail) -> torch.Tensor: + def _backbone(self, state: torch.Tensor, rail: torch.Tensor) -> torch.Tensor: inp = self.net(rail) inp = inp.mean((2, 3), keepdim=True).squeeze(-1) state = torch.cat([self.decoder_input(state) + inp, self.memory_tensor.expand(inp.size(0), -1, -1)], 2) @@ -203,8 +250,8 @@ def forward(self, state, rail) -> torch.Tensor: for block in self.decoder: state = attention(block(state)) state = state[:, :, :-self.memory_size] - out = self.end_conv(self.end_norm(state)) - return torch.nn.functional.softmax(out, 1) if self.softmax else out + out = self.end_norm(state) + return out def init(module: torch.nn.Module): @@ -232,10 +279,10 @@ def forward(self, fn_input: torch.Tensor) -> torch.Tensor: return fn_input + out -class QNetwork(torch.nn.Sequential): - def __init__(self, state_size, action_size, hidden_factor=16, depth=4, kernel_size=7, squeeze_heads=4, - decoder_depth=1, cat=False, debug=True): - super(QNetwork, self).__init__() +class QNetwork(TailModel): + def __init__(self, state_size, hidden_factor=16, depth=4, kernel_size=7, squeeze_heads=4, + decoder_depth=1, cat=False, debug=True, tail=None): + super(QNetwork, self).__init__(tail) _ = depth _ = kernel_size _ = squeeze_heads @@ -245,12 +292,8 @@ def __init__(self, state_size, action_size, hidden_factor=16, depth=4, kernel_si self.model = torch.nn.Sequential(torch.nn.Conv1d(2 * state_size, 11 * hidden_factor, 1, groups=11, bias=False), Residual(11 * hidden_factor), torch.nn.BatchNorm1d(11 * hidden_factor), - Mish(), - torch.nn.Conv1d(11 * hidden_factor, action_size, 1)) - - @torch.jit.export - def reset_cache(self): - pass + Mish()) - def forward(self, *args): - return self.model(*args) + def _backbone(self, fn_input: torch.Tensor, trash: torch.Tensor) -> torch.Tensor: + _ = trash + return self.model(fn_input) diff --git a/src/train.py b/src/train.py index 0f2b0fb..f3218ff 100644 --- a/src/train.py +++ b/src/train.py @@ -172,9 +172,9 @@ def as_tensor(array_list): def make_tensor(current, old): if model_type == 1: - tensor = (torch.cat([old[0], current[0]], -1),) + tensor = (torch.cat([old[0], current[0]], -1), torch.zeros(1)) elif model_type == 0: - tensor = (torch.cat([old[0].flatten(1, 2), current[0].flatten(1, 2)], 1),) + tensor = (torch.cat([old[0].flatten(1, 2), current[0].flatten(1, 2)], 1), torch.zeros(1)) else: # model_type == 2 tensor = (torch.cat((old[0], current[0]), 1), torch.cat((old[1], current[1]), -1)) tensor[1].transpose_(1, -1) @@ -220,7 +220,8 @@ def get_observation_tensor(observation, prev_tensor=None): score, collision = 0, False agent_count = len(environments[0].agents) - agent_action_buffer = [[2] * agent_count for _ in range(BATCH_SIZE)] + agent_action_buffer = torch.zeros((BATCH_SIZE, 5, agent_count), device=device, dtype=torch.float) + agent_action_buffer[:, 2] += 1 agent_obs, _ = get_observation_tensor(obs) agent_obs_buffer = clone(agent_obs) input_tensor = make_tensor(agent_obs, agent_obs_buffer) @@ -234,7 +235,7 @@ def get_observation_tensor(observation, prev_tensor=None): for step in range(max_steps): done = _done - ret_action = agent.multi_act(input_tensor) + mdl_action, ret_action = agent.multi_act(input_tensor, False) action_dict = [dict(enumerate(act_list)) for act_list in ret_action] # Environment step @@ -256,7 +257,7 @@ def get_observation_tensor(observation, prev_tensor=None): next_input, flags.step_reward, flags.collision_reward) - agent_action_buffer = [[act[i] for i in range(agent_count)] for act in action_dict] + agent_action_buffer = mdl_action if all_done: break