Changes from all commits (75 commits)
9fac1da
I'm seriously not gonna split this
ClashLuke Jul 8, 2020
4870b92
feat: add concatenation script for railways
ClashLuke Jul 8, 2020
2c4e1c9
feat: add slightly improved dqn agent
ClashLuke Jul 8, 2020
9d8a8fc
perf: use torch only (no numpy)
ClashLuke Jul 8, 2020
6669ed3
feat: add baseline model
ClashLuke Jul 8, 2020
2e83da3
feat: train all agents at the same time
ClashLuke Jul 8, 2020
9fde258
feat: add attention
ClashLuke Jul 8, 2020
6c6664d
feat: improve speed
ClashLuke Jul 9, 2020
e0ac945
feat: process one batch at a time instead of caching
ClashLuke Jul 9, 2020
29fa930
perf: use torch for of observation_utils.py
ClashLuke Jul 9, 2020
4b40ae8
perf: remove all numpy
ClashLuke Jul 9, 2020
bc29151
perf: use iterative solution instead of recursive one (~+20%)
ClashLuke Jul 9, 2020
e3e0920
perf: add multiprocessing (+30%)
ClashLuke Jul 9, 2020
c966052
perf: init removal of redundant vars
ClashLuke Jul 10, 2020
3bf53d4
Revert "perf: init removal of redundant vars"
ClashLuke Jul 10, 2020
3425712
perf: cythonize normalizer
ClashLuke Jul 10, 2020
6a85249
perf: cythonize tree observation
ClashLuke Jul 10, 2020
9491429
perf: reduce model size, use single qn
ClashLuke Jul 10, 2020
0418a67
perf: cythonize rail_env
ClashLuke Jul 10, 2020
a55ffaf
style: pycharm reforamt
ClashLuke Jul 10, 2020
4c805d3
perf: add typehints to railenv
ClashLuke Jul 10, 2020
e7b6a9f
fix: re-add support for unitialized transition
ClashLuke Jul 10, 2020
0ff41d2
style: move cat to correct folder
ClashLuke Jul 10, 2020
4b3b666
perf: use ppo
ClashLuke Jul 11, 2020
18e5597
feat: give ppo network previous state
ClashLuke Jul 11, 2020
af8952a
feat: railway_utils.py increase agent count over time
ClashLuke Jul 13, 2020
a904cf4
feat(train): add comfort functions to train.py
ClashLuke Jul 13, 2020
cd6207f
fix(model): use dropout instead of bernoulli, init properly
ClashLuke Jul 13, 2020
9fd919f
feat(train): add cli args for width/agents of env
ClashLuke Jul 13, 2020
d284c57
feat(rail-generator): cythonize
ClashLuke Jul 13, 2020
e72fe0d
style(train): remove unused commandline arguments
ClashLuke Jul 13, 2020
5537b95
feat(railway-utils): use custom generator
ClashLuke Jul 13, 2020
413d571
perf(generate-railways): undo cythonizing
ClashLuke Jul 13, 2020
a953f13
style: remove legacy code
ClashLuke Jul 13, 2020
1cbabce
fix(model): message box for agents (not features)
ClashLuke Jul 13, 2020
e9cf607
feat: add global, local env
ClashLuke Jul 17, 2020
59e385e
feat: re-enable cuda (after reducing input size)
ClashLuke Jul 17, 2020
36c6e4a
feat: add readme to keep folder structure
ClashLuke Jul 17, 2020
af2ec5c
fix: add garbage to gitignore
ClashLuke Jul 17, 2020
9c4e4b3
style: major cleanup
ClashLuke Jul 18, 2020
ea376d4
fix: remove NaN's
ClashLuke Jul 18, 2020
3830353
perf(agent): add JIT, remove unnecessary list comprehension
ClashLuke Jul 19, 2020
e83378e
perf(model): use instancenorm, add message box
ClashLuke Jul 19, 2020
4c35628
style: remove unused variables
ClashLuke Jul 19, 2020
bc1cb82
perf(observation): use int8
ClashLuke Jul 19, 2020
cdacc24
perf(rail-env): remove type enforcing
ClashLuke Jul 19, 2020
e0b244f
perf(train): add support for np.ndarray observation
ClashLuke Jul 19, 2020
ae8073c
feat: add finish rate
ClashLuke Jul 19, 2020
f50e92f
style(observation-utils): add return_array parameter (backwards compa…
ClashLuke Jul 19, 2020
dbee685
perf(train): mildly improve sum iterator
ClashLuke Jul 19, 2020
b0107f6
fix(train): use correct variables for running stats
ClashLuke Jul 19, 2020
55c72d5
style(model): remove unused deps
ClashLuke Jul 19, 2020
78936f9
perf(agent): first calculate divisor (+cleanup), then do big op
ClashLuke Jul 20, 2020
f695f65
perf(model): fix stride
ClashLuke Jul 20, 2020
325ae4e
fix(train): re-add support for model loading
ClashLuke Jul 20, 2020
3759e77
fix(cythonize): use gcc9 instead of 7 (9 doens't work with cuda10)
ClashLuke Jul 20, 2020
a11a470
perf(rail_env): remove rtol for 0 values
ClashLuke Jul 20, 2020
c2db1e5
feat(train): improve the interface by adding better outputs, _always_…
ClashLuke Jul 20, 2020
ae20951
feat(train): remove cross-batch mean calculation
ClashLuke Jul 21, 2020
696da85
feat: add global-state model (globalobs + agent-state)
ClashLuke Jul 24, 2020
c10860b
perf(agent): enforce sparse gradient (experiment)
ClashLuke Jul 24, 2020
c46b1c5
feat(interface): move depth-1
ClashLuke Jul 24, 2020
6ee0008
feat(interface): move depth-1
ClashLuke Jul 25, 2020
50503ea
perf(cython): add header instructions
ClashLuke Jul 25, 2020
7a69bbf
perf(model): remove unused variable
ClashLuke Jul 25, 2020
e0da5f2
perf(agent): use np.where instead of list-list-list-comprehension
ClashLuke Jul 25, 2020
dfe0a18
perf(agent): remove dead constants
ClashLuke Jul 25, 2020
7e3a92b
perf(model): f(rail) at every step
ClashLuke Jul 25, 2020
5962a31
feat(railway_utils): adapt new file schema
ClashLuke Jul 25, 2020
110fbc8
perf(model): add attention
ClashLuke Jul 25, 2020
eb1b7b4
fix(agent): step every n steps, not every n data types
ClashLuke Jul 25, 2020
32dc07c
style(train): improve maintainability
ClashLuke Jul 26, 2020
899e03f
feat(agent): readd dqn
ClashLuke Jul 26, 2020
4a6077f
fix(agent): readd replay buffer, loss functions
ClashLuke Jul 26, 2020
bb9db64
feat(agent): improve maintainability, add naf
ClashLuke Jul 27, 2020
21 changes: 21 additions & 0 deletions .gitignore
@@ -9,3 +9,24 @@ venv/
.DS_Store

*.pyc
/src/a.out
/index.html
/checkpoints/dqn/dqn0/loss.txt
/.idea/**
**model_checkpoint*
/.idea/modules.xml
/src/obs.so
/src/observation_utils.c
/src/observation_utils.h
/src/observation_utils.so
/.idea/other.xml
/.idea/inspectionProfiles/profiles_settings.xml
/.idea/inspectionProfiles/Project_Default.xml
/r/rail_networks_35x40x40.pkl
/r/rail_networks_sum.pkl
/r/schedules_35x40x40.pkl
/r/schedules_sum.pkl
/.idea/vcs.xml
*.c
*.so
*.out
77 changes: 57 additions & 20 deletions README.md
@@ -1,33 +1,70 @@
# flatland-training

This repo contains an optimized version of flatland-rl's `flatland.envs.observations.TreeObsForRailEnv`. Tree-based observations allow RL models to learn much more quickly than the global observations do, but flatland's built-in TreeObsForRailEnv is kind of slow, so I wrote a faster version! This repo also contains an optimized version of [https://gitlab.aicrowd.com/flatland/baselines/blob/master/utils/observation_utils.py](https://gitlab.aicrowd.com/flatland/baselines/blob/master/utils/observation_utils.py), which flattens and normalizes the tree observations into 1D numpy arrays that can be passed to a feed-forward network.
PyTorch solution for [flatland-2020](https://www.aicrowd.com/challenges/neurips-2020-flatland-challenge/)

## Implementation

# Setup
## Create venv
`python3.7 -m venv venv`\
`source venv/bin/activate`
This repository contains three major modules.

Verify python version is correct with: `python -V`\
Should return `Python 3.7.something`
### Getting Started

## Install Requirements
`pip install -r requirements.txt`
#### Setup

Before following along, please note that there is an `install.sh` script which executes all of the commands below.\
First, create a virtual environment using `python3.7 -m venv venv && source venv/bin/activate`.\
Then install the requirements with `python3 -m pip install -r requirements.txt`.\
It is recommended to verify that the installation was successful by first checking the Python version and then attempting an import of all required non-standard packages.
```bash
$ python3 --version
Python 3.7.6
$ python3 -c "import torch, torch_optimizer, numpy, cython, flatland, gym, tqdm; print('Successfully imported packages')"
Successfully imported packages
```
Lastly comes perhaps the most crucial step: compilation requires `gcc-7`, as no other version works. On Debian/Ubuntu, it can be installed with the apt package manager by running `apt install gcc-7`.\
Once that's done, the Python code can be compiled using Cython by moving into the source folder and executing cythonize.sh: `cd src && bash cythonize.sh`.
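
For reference, the build step boils down to compiling the `.pyx` modules in place with Cython while forcing `gcc-7` as the compiler. A rough sketch under those assumptions — the authoritative commands live in `src/cythonize.sh`:

```bash
# Illustrative only; the real build commands are in src/cythonize.sh.
cd src
CC=gcc-7 cythonize -i observation_utils.pyx rail_env.pyx   # build the Cython extensions in place
```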

# Generate Railways
This script will precompute a bunch of railway maps to make training faster.\
#### Generate Environments

`python src/generate_railways.py`
For better training performance, the environments used to train the network can optionally be generated _before_ training. This makes training much faster, since training data doesn't have to be regenerated repeatedly but is instead loaded once at startup.\
To generate the environments and their respective railways, run `python3 src/generate_railways.py --width 50`, which creates 50x50 grids of cities, rails and trains.

This will run for quite a long time, go get some tea...\
It's also fine to stop it after at least one round completes if you just want to test things out and make sure they run.\
If you don't care about the speedup, you can run `python src/train.py --load-railways=False` to generate railways on the fly during training instead.
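
Putting the two options side by side (commands taken from the instructions above; the exact spelling of `--load-railways` may differ between versions, so check `--help`):

```bash
# Option 1: pre-generate railways once, then train on the cached environments.
python3 src/generate_railways.py --width 50
python3 src/train.py

# Option 2: skip pre-generation and build railways on the fly (slower, but no waiting upfront).
python3 src/train.py --load-railways=False
```
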
#### Run Training

Finally, it's time to train the model. You can do so by running `python3 src/train.py`, which will train a basic CNN using the "local observation" method. ![https://flatland.aicrowd.com/getting-started/env.html](https://i.imgur.com/oo8EIYv.png)\
Global and tree observations are implemented as well, along with many configurable model parameters. To find out more about them, add the `--help` flag.

# Run Training
`python src/train.py`
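
As an illustration of the kind of switches exposed — the flag names below are assumptions based on the commit history, and `--help` prints the authoritative list:

```bash
python3 src/train.py --help                 # list every model and environment parameter
python3 src/train.py --width 40 --agents 4  # hypothetical flag names for grid size and agent count
```
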
### Structure

This will begin training one or more agents in the flatland environment.\
This file has a lot of parameters that can be set to do different things.\
To see all the options, use the `--help` command line argument.
Currently, the code is structured as a few large, monolithic files.

| Name | Description |
|----|----|
| agent.py | Reward and training algorithm, as well as some hyperparameters (such as batch size and learning rate) |
| generate_railways.py | Script to pre-compute and generate railways from the command line. See [#Generate Environments](#Generate-Environments) |
| model.py | PyTorch definition of the tree-observation and local-observation models |
| observation_utils.pyx | Agent observation utilities called by the environment to create training observations |
| rail_env.pyx | Cython port of flatland-rl's RailEnv |
| railway_utils.py | Utility script handling creation of, and iteration over, railways |
| train.py | Core training loop |

### Future Work

The current implementation has many holes. One of them is the very poor performance when controlling many (>10) agents at once.\
We are tackling this issue from multiple sides at once. If you would like to join the team, open an issue or a pull request, or join us on [Discord](https://discord.gg/mP72wbE).\
Our current approaches are listed below, followed by a short PPO sketch:

* **Observation, Model**:
* Tree observation, graph neural networks
* Tree observation, fully-connected networks
* Tree observation, transformer
* Local observation, cnn
* Global observation, cnn
* **Teaching algorithm**:
* PPO
* (Double-) DQN
* **Misc. Freebies**:
* Epsilon-Greedy
* Multiprocessing
* Inter-agent communication
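
For readers new to the teaching algorithms above, the core of the PPO update used in `agent.py` is the clipped probability ratio. The snippet below is a simplified, self-contained sketch; the real `_ppo_loss` works on tuple observations and uses a max-then-sum aggregation instead of the mean:

```python
import torch

def ppo_clip_loss(new_probs, old_probs, advantages, clip=0.2):
    """Clipped PPO objective, returned as a loss to minimize."""
    ratio = new_probs / (old_probs + 1e-5)                  # how much the new policy favors the taken actions
    clamped = torch.clamp(ratio, 1.0 - clip, 1.0 + clip)    # keep the update inside the trust region
    return -torch.min(ratio * advantages, clamped * advantages).mean()

# Toy usage with random action probabilities and advantages.
print(ppo_clip_loss(torch.rand(8), torch.rand(8), torch.randn(8)))
```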

If you are working on one of these tasks or would like to, please open an issue or pull request to let others know about it. Once reviewed, it will be added to the main repository.
1 change: 1 addition & 0 deletions checkpoints/dqn/dqn0/README.md
@@ -0,0 +1 @@
DQN checkpoints will be saved here
1 change: 0 additions & 1 deletion checkpoints/ppo/README.md

This file was deleted.

16 changes: 16 additions & 0 deletions install.sh
@@ -0,0 +1,16 @@
#!/usr/bin/env bash

function check_python {
python3 -c "import torch, torch_optimizer, numpy, cython, flatland, gym, tqdm; print('Successfully imported packages')" 2>/dev/null\
|| (python3.7 -m venv venv && source venv/bin/activate && python3 -m pip install -r requirements.txt) || check_python
}

check_python

echo "Checking for GCC-7"

gcc-7 --version > /dev/null || sudo apt install gcc-7

echo "Compiling source"
cd src && source cythonize.sh > /dev/null 2>/dev/null
cd ..
1 change: 0 additions & 1 deletion railroads/README.md

This file was deleted.

15 changes: 7 additions & 8 deletions requirements.txt
@@ -1,8 +1,7 @@
argparse==1.4.0
flatland-rl==2.2.1
numpy==1.18.5
torch==1.5.0
opencv-python==4.2.0.34
Pillow==7.1.2
tqdm==4.46.1
tensorboardX==2.0
torch
torch-optimizer
numpy
cython
flatland-rl
gym
tqdm
224 changes: 224 additions & 0 deletions src/agent.py
@@ -0,0 +1,224 @@
import math
import pickle
import typing

import numpy as np
import torch
from torch_optimizer import Yogi as Optimizer

try:
    from .model import QNetwork, ConvNetwork, init, GlobalStateNetwork, TripleClassificationHead, NAFHead
except ImportError:
    from model import QNetwork, ConvNetwork, init, GlobalStateNetwork, TripleClassificationHead, NAFHead
import os

BATCH_SIZE = 256
CLIP_FACTOR = 0.2
LR = 1e-4
UPDATE_EVERY = 1
CUDA = True
MINI_BACKWARD = False
DQN_TAU = 1e-3
EPOCHS = 1

device = torch.device("cuda:0" if CUDA and torch.cuda.is_available() else "cpu")


@torch.jit.script
def aggregate(loss: torch.Tensor):
    # Take the elementwise maximum over the first dimension, then sum the remaining entries.
    maximum, _ = loss.max(0)
    return maximum.sum()


@torch.jit.script
def mse(in_x, in_y):
    return aggregate((in_x - in_y).square())


@torch.jit.script
def dqn_target(rewards, targets_next, done):
    # Bellman target with a discount factor of 0.998; `done` masks out bootstrapping for finished agents.
    return rewards + 0.998 * targets_next * (1 - done)


class Agent(torch.nn.Module):
    def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_size, squeeze_heads, decoder_depth,
                 model_type=0, softmax=True, debug=True, loss_type='PPO'):
        super(Agent, self).__init__()
        self.action_size = action_size

        # Q-Network
        if model_type == 1:  # Global/Local
            network = ConvNetwork
        elif model_type == 0:  # Tree
            network = QNetwork
        else:  # Global State
            network = GlobalStateNetwork
        if loss_type in ('PPO', 'DQN'):
            tail = TripleClassificationHead(hidden_factor, action_size)
        else:
            tail = NAFHead(hidden_factor, action_size)
        self.policy = network(state_size,
                              hidden_factor,
                              model_depth,
                              kernel_size,
                              squeeze_heads,
                              decoder_depth,
                              tail=tail,
                              softmax=softmax).to(device)
        self.old_policy = network(state_size,
                                  hidden_factor,
                                  model_depth,
                                  kernel_size,
                                  squeeze_heads,
                                  decoder_depth,
                                  tail=tail,
                                  softmax=softmax,
                                  debug=False).to(device)
        if debug:
            print(self.policy)

            parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, self.policy.parameters()))
            digits = int(math.log10(parameters))
            number_string = " kMGTPEZY"[digits // 3]

            print(f"[DEBUG/MODEL] Training with {parameters * 10 ** -(digits // 3 * 3):.1f}"
                  f"{number_string} parameters")
        self.policy.apply(init)
        try:
            self.policy = torch.jit.script(self.policy)
            self.old_policy = torch.jit.script(self.old_policy)
        except Exception:
            import traceback
            traceback.print_exc()
            print("NO JIT")
        self.old_policy.load_state_dict(self.policy.state_dict())
        self.optimizer = Optimizer(self.policy.parameters(), lr=LR, weight_decay=1e-2)

        # Replay memory
        self.stack = [[] for _ in range(6)]
        self.t_step = 0
        self.idx = 1
        self.tensor_stack = []
        self._policy_update = loss_type in ("PPO",)
        self._soft_update = loss_type in ("DQN", "NAF")

        self._action_index = torch.zeros(1)
        self._value_index = torch.zeros(1) + 1
        self._triangular_index = torch.zeros(1) + 2

        self.loss = getattr(self, f'_{loss_type.lower()}_loss')

    def _dqn_loss(self, states, actions, next_states, rewards, done):
        # Double DQN: the online policy selects the best next action, the target network evaluates it.
        actions = actions.argmax(1)
        expected = self.policy(self._action_index, self._action_index, *states).gather(1, actions)
        best_action = self.policy(self._action_index, self._action_index, *next_states).argmax(1)
        targets_next = self.old_policy(self._action_index, self._action_index,
                                       *next_states).gather(1, best_action.unsqueeze(1))
        targets = dqn_target(rewards, targets_next, done)
        loss = mse(expected, targets)
        return loss

    def _naf_loss(self, states, actions, next_states, rewards, done):
        targets_next = self.old_policy(self._value_index, self._action_index, next_states)
        state_action_values = self.policy(self._triangular_index, actions, states)
        targets = dqn_target(rewards, targets_next, done)
        loss = mse(state_action_values, targets)
        return loss

    def _ppo_loss(self, states, actions, next_states, rewards, done):
        _ = next_states
        _ = done
        actions = actions.argmax(1)
        states_clone = [st.clone().detach().requires_grad_(False) for st in states]
        old_responsible_outputs = self.old_policy(self._action_index, self._action_index,
                                                  *states_clone).gather(1, actions).detach_()
        responsible_outputs = self.policy(self._action_index, self._action_index, *states).gather(1, actions)
        ratio = responsible_outputs / (old_responsible_outputs + 1e-5)
        clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR)
        loss = aggregate(torch.min(ratio * rewards, clamped_ratio * rewards)).neg()
        return loss

    def reset(self):
        self.policy.reset_cache()
        self.old_policy.reset_cache()

    def multi_act(self, state, argmax_only=True) -> typing.Union[typing.Tuple[np.ndarray, np.ndarray], np.ndarray]:
        self.policy.eval()
        with torch.no_grad():
            action_values = self.policy(self._action_index, self._action_index, *state).detach()
            argmax = action_values.argmax(1).cpu().numpy()
        if argmax_only:
            return argmax
        return action_values, argmax

    def step(self, state, action, agent_done, collision, next_state, step_reward=0, collision_reward=-2):
        agent_count = len(agent_done[0]) - 1
        self.stack[0].append(state)
        self.stack[1].append(action)
        self.stack[2].append([[done[idx] for idx in range(agent_count)] for done in agent_done])
        self.stack[3].append(collision)
        self.stack[4].append(next_state)

        if MINI_BACKWARD or len(self.stack[0]) >= UPDATE_EVERY:
            action = torch.cat(self.stack[1]).to(device).unsqueeze_(1)
            agent_done = np.array(self.stack[2])
            collision = np.array(self.stack[3])
            # Reward shaping: +1 for finished agents, collision_reward on collision, step_reward otherwise.
            reward = np.where(agent_done, 1, np.where(collision, collision_reward, step_reward))
            reward = torch.tensor(reward, device=device, dtype=torch.float).flatten(0, 1).unsqueeze_(1)
            state = tuple(torch.cat(st, 0) for st in zip(*self.stack[0]))
            next_state = tuple(torch.cat(st, 0) for st in zip(*self.stack[4]))
            agent_done = torch.as_tensor(agent_done, device=device, dtype=torch.int8).flatten(0, 1).unsqueeze_(1)
            self.stack = [[] for _ in range(6)]
            self.tensor_stack.append((state, action, reward, next_state, agent_done))
            if len(self.tensor_stack) >= EPOCHS:
                tensor_stack = (torch.cat(t, 0) if isinstance(t[0], torch.Tensor)
                                else tuple(torch.cat(sub_t, 0) for sub_t in zip(*t))
                                for t in zip(*self.tensor_stack))
                del self.tensor_stack[0]
                self.learn(*tensor_stack)

    def learn(self, states, actions, rewards, next_states, done):
        if MINI_BACKWARD:
            self.idx = (self.idx + 1) % UPDATE_EVERY

        self.policy.train()

        loss = self.loss(states, actions, next_states, rewards, done)

        if MINI_BACKWARD:
            loss = loss / UPDATE_EVERY

        loss.backward()

        if not MINI_BACKWARD or self.idx == 0:
            if self._policy_update:
                self.old_policy.load_state_dict(self.policy.state_dict())
            self.optimizer.step()
            self.optimizer.zero_grad()
            if self._soft_update:
                # Polyak averaging of the target network towards the online network.
                for target_param, local_param in zip(self.old_policy.parameters(), self.policy.parameters()):
                    target_param.data.copy_(DQN_TAU * local_param.data +
                                            (1.0 - DQN_TAU) * target_param.data)

    def save(self, path, *data):
        torch.save(self.policy.state_dict(), path / 'dqn/model_checkpoint.local')
        torch.save(self.old_policy.state_dict(), path / 'dqn/model_checkpoint.target')
        torch.save(self.optimizer.state_dict(), path / 'dqn/model_checkpoint.optimizer')
        with open(path / 'dqn/model_checkpoint.meta', 'wb') as file:
            pickle.dump(data, file)

    def load(self, path, *defaults):
        loc = {} if torch.cuda.is_available() else {'map_location': torch.device('cpu')}
        try:
            print("Loading model from checkpoint...")
            dqn = os.path.join(path, 'dqn')
            self.policy.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.local'), **loc))
            self.old_policy.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.target'), **loc))
            self.optimizer.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.optimizer'), **loc))
            with open(os.path.join(dqn, 'model_checkpoint.meta'), 'rb') as file:
                return pickle.load(file)
        except Exception as exc:
            import traceback
            traceback.print_exc()
            print(f"Got exception {exc} loading model data. Possibly no checkpoint found.")
            return defaults