diff --git a/.gitignore b/.gitignore
index 2a88583..6f4d625 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,24 @@
 venv/
 .DS_Store
 *.pyc
+/src/a.out
+/index.html
+/checkpoints/dqn/dqn0/loss.txt
+/.idea/**
+**model_checkpoint*
+/.idea/modules.xml
+/src/obs.so
+/src/observation_utils.c
+/src/observation_utils.h
+/src/observation_utils.so
+/.idea/other.xml
+/.idea/inspectionProfiles/profiles_settings.xml
+/.idea/inspectionProfiles/Project_Default.xml
+/r/rail_networks_35x40x40.pkl
+/r/rail_networks_sum.pkl
+/r/schedules_35x40x40.pkl
+/r/schedules_sum.pkl
+/.idea/vcs.xml
+*.c
+*.so
+*.out
\ No newline at end of file
diff --git a/README.md b/README.md
index 8694333..a494a6d 100644
--- a/README.md
+++ b/README.md
@@ -1,33 +1,70 @@
 # flatland-training
-This repo contains an optimized version of flatland-rl's `flatland.envs.observations.TreeObsForRailEnv`. Tree-based observations allow RL models to learn much more quickly than the global observations do, but flatland's built-in TreeObsForRailEnv is kind of slow, so I wrote a faster version! This repo also contains an optimized version of [https://gitlab.aicrowd.com/flatland/baselines/blob/master/utils/observation_utils.py](https://gitlab.aicrowd.com/flatland/baselines/blob/master/utils/observation_utils.py), which flattens and normalizes the tree observations into 1D numpy arrays that can be passed to a feed-forward network.
+PyTorch solution for [flatland-2020](https://www.aicrowd.com/challenges/neurips-2020-flatland-challenge/)
+## Implementation
-# Setup
-## Create venv
-`python3.7 -m venv venv`\
-`source venv/bin/activate`
+This repository contains three major modules.
-Verify python version is correct with: `python -V`\
-Should return `Python 3.7.something`
+### Getting Started
-## Install Requirements
-`pip install -r requirements.txt`
+#### Setup
+Before following along, note that the `install.sh` script executes all of the commands in this section.\
+First, create a virtual environment with `python3.7 -m venv venv && source venv/bin/activate`.\
+Then install the requirements with `python3 -m pip install -r requirements.txt`.\
+It is recommended to verify that the installation succeeded by first checking the Python version and then importing all required non-standard packages.
+```bash
+$ python3 --version
+Python 3.7.6
+$ python3 -c "import torch, torch_optimizer, numpy, cython, flatland, gym, tqdm; print('Successfully imported packages')"
+Successfully imported packages
+```
+Lastly comes perhaps the most crucial step: compiling the Cython sources. This requires `gcc-7`, as no other version works. On Debian/Ubuntu, it can be installed with the apt package manager by running `apt install gcc-7`.\
+Once that is done, the Python code can be compiled with Cython by moving into the source folder and executing cythonize.sh via `cd src && bash cythonize.sh`.
-# Generate Railways
-This script will precompute a bunch of railway maps to make training faster.\
+#### Generate Environments
-`python src/generate_railways.py`
+For better training performance, the environments used to train the network can optionally be generated _before_ training. Training then runs much faster, because the training data is loaded once at startup instead of being regenerated repeatedly.\
+To generate environments and their respective railways, run `python3 src/generate_railways.py --width 50`. The command generates 50x50 grids of cities, rails and trains.
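+The script writes its results as pickled lists under `railroads/`; the file names follow `src/generate_railways.py` (`rail_networks_<width>_<factor>.pkl` and `schedules_<width>_<factor>.pkl`, with `--factor` defaulting to 2). Below is a minimal sketch of how the generated data could be inspected afterwards; the path handling is illustrative only and not part of the repository:
+```python
+# Illustrative only: inspect the pickles produced by `python3 src/generate_railways.py --width 50`.
+import pickle
+from pathlib import Path
+
+root = Path('.')  # assumed: run from the repository root
+with open(root / 'railroads/rail_networks_50_2.pkl', 'rb') as f:
+    rail_networks = pickle.load(f)  # list of (rail_grid, info) tuples
+with open(root / 'railroads/schedules_50_2.pkl', 'rb') as f:
+    schedules = pickle.load(f)      # one schedule per rail network
+print(f"{len(rail_networks)} rail networks, {len(schedules)} schedules")
+```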
-This will run for quite a long time, go get some tea...\
-But also it's fine to stop it after it completes at least one round if you just want to test things out and make sure they run.\
-If you don't care about the speedup, you can run `python src/train.py --load-railways=False` to generate railways on the fly during training instead.
+#### Run Training
+Finally, it's time to train the model. You can do so by running `python3 src/train.py`, which trains a basic CNN using the "local observation" method. ![https://flatland.aicrowd.com/getting-started/env.html](https://i.imgur.com/oo8EIYv.png)\
+Global and tree observations are implemented as well, along with many configurable model parameters. To find out more about them, add the `--help` flag.
-# Run Training
-`python src/train.py`
+### Structure
-This will begin training one or more agents in the flatland environment.\
-This file has a lot of parameters that can be set to do different things.\
-To see all the options, use the `--help` command line argument.
+Currently, the code is structured as a few huge monolithic files.
+
+| Name | Description |
+|----|----|
+| agent.py | Reward and training algorithm, as well as some hyperparameters (such as batch size and learning rate) |
+| generate_railways.py | Script to pre-compute and generate railways from the command line. See [#Generate Environments](#Generate-Environments) |
+| model.py | PyTorch definition of the tree-observation and local-observation models |
+| observation_utils.pyx | Agent observation utilities called by the environment to create training observations |
+| rail_env.pyx | Cython port of the flatland-rl RailEnv |
+| railway_utils.py | Utility script handling the creation of, and iteration over, railways |
+| train.py | Core training loop |
+
+### Future Work
+
+The current implementation still has many gaps. One of them is the very poor performance when controlling many (>10) agents at once.\
+We are tackling this issue from multiple sides at once. If you would like to participate, open an issue or pull request, or join us on [discord](https://discord.gg/mP72wbE).
+Our current approaches are listed below:
+
+* **Observation, Model**:
+  * Tree observation, graph neural networks
+  * Tree observation, fully-connected networks
+  * Tree observation, transformer
+  * Local observation, CNN
+  * Global observation, CNN
+* **Teaching algorithm** (see the sketch at the end of this section):
+  * PPO
+  * (Double-) DQN
+* **Misc. Freebies**:
+  * Epsilon-Greedy
+  * Multiprocessing
+  * Inter-agent communication
+
+If you are working on one of these tasks or would like to do so, please open an issue or pull request to let others know about it. Once seen, it will be added to the main repository.
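+For contributors, the snippet below is a minimal, self-contained sketch of the clipped-PPO objective that `src/agent.py` uses by default (`loss_type='PPO'`). The tensors are random placeholders for the gathered action probabilities, so it only illustrates the shape of the loss, not the full training step:
+```python
+import torch
+
+CLIP_FACTOR = 0.2  # same constant as in src/agent.py
+
+def clipped_ppo_loss(probs, old_probs, rewards):
+    ratio = probs / (old_probs + 1e-5)  # new/old probability of the chosen action
+    clamped = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR)
+    surrogate = torch.min(ratio * rewards, clamped * rewards)
+    return surrogate.max(0)[0].sum().neg()  # mirrors agent.py's `aggregate`, negated for minimization
+
+probs = torch.rand(8, 1)      # P(action) under the current policy
+old_probs = torch.rand(8, 1)  # P(action) under the frozen old policy
+rewards = torch.randn(8, 1)
+print(clipped_ppo_loss(probs, old_probs, rewards))
+```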
\ No newline at end of file diff --git a/checkpoints/dqn/dqn0/README.md b/checkpoints/dqn/dqn0/README.md new file mode 100644 index 0000000..7792019 --- /dev/null +++ b/checkpoints/dqn/dqn0/README.md @@ -0,0 +1 @@ +DQN checkpoints will be saved here diff --git a/checkpoints/ppo/README.md b/checkpoints/ppo/README.md deleted file mode 100644 index 700b9ec..0000000 --- a/checkpoints/ppo/README.md +++ /dev/null @@ -1 +0,0 @@ -PPO checkpoints will be saved here diff --git a/install.sh b/install.sh new file mode 100755 index 0000000..7b752d2 --- /dev/null +++ b/install.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +function check_python { +python3 -c "import torch, torch_optimizer, numpy, cython, flatland, gym, tqdm; print('Successfully imported packages')" 2>/dev/null\ +|| (python3.7 -m venv venv && source venv/bin/activate && python3 -m pip install -r requirements.txt) || check_python +} + +check_python + +echo "Checking for GCC-7" + +gcc-7 --version > /dev/null || sudo apt install gcc-7 + +echo "Compiling source" +cd src && source cythonize.sh > /dev/null 2>/dev/null +cd .. \ No newline at end of file diff --git a/railroads/README.md b/railroads/README.md deleted file mode 100644 index 5baf854..0000000 --- a/railroads/README.md +++ /dev/null @@ -1 +0,0 @@ -Generated railroads will be saved here diff --git a/requirements.txt b/requirements.txt index 2fc4cc6..8c292f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,7 @@ -argparse==1.4.0 -flatland-rl==2.2.1 -numpy==1.18.5 -torch==1.5.0 -opencv-python==4.2.0.34 -Pillow==7.1.2 -tqdm==4.46.1 -tensorboardX==2.0 +torch +torch-optimizer +numpy +cython +flatland-rl +gym +tqdm \ No newline at end of file diff --git a/src/agent.py b/src/agent.py new file mode 100644 index 0000000..0df4269 --- /dev/null +++ b/src/agent.py @@ -0,0 +1,224 @@ +import math +import pickle +import typing + +import numpy as np +import torch +from torch_optimizer import Yogi as Optimizer + +try: + from .model import QNetwork, ConvNetwork, init, GlobalStateNetwork, TripleClassificationHead, NAFHead +except: + from model import QNetwork, ConvNetwork, init, GlobalStateNetwork, TripleClassificationHead, NAFHead +import os + +BATCH_SIZE = 256 +CLIP_FACTOR = 0.2 +LR = 1e-4 +UPDATE_EVERY = 1 +CUDA = True +MINI_BACKWARD = False +DQN_TAU = 1e-3 +EPOCHS = 1 + +device = torch.device("cuda:0" if CUDA and torch.cuda.is_available() else "cpu") + + +@torch.jit.script +def aggregate(loss: torch.Tensor): + maximum, _ = loss.max(0) + return maximum.sum() + + +@torch.jit.script +def mse(in_x, in_y): + return aggregate((in_x - in_y).square()) + + +@torch.jit.script +def dqn_target(rewards, targets_next, done): + return rewards + 0.998 * targets_next * (1 - done) + + +class Agent(torch.nn.Module): + def __init__(self, state_size, action_size, model_depth, hidden_factor, kernel_size, squeeze_heads, decoder_depth, + model_type=0, softmax=True, debug=True, loss_type='PPO'): + super(Agent, self).__init__() + self.action_size = action_size + + # Q-Network + if model_type == 1: # Global/Local + network = ConvNetwork + elif model_type == 0: # Tree + network = QNetwork + else: # Global State + network = GlobalStateNetwork + if loss_type in ('PPO', 'DQN'): + tail = TripleClassificationHead(hidden_factor, action_size) + else: + tail = NAFHead(hidden_factor, action_size) + self.policy = network(state_size, + hidden_factor, + model_depth, + kernel_size, + squeeze_heads, + decoder_depth, + tail=tail, + softmax=softmax).to(device) + self.old_policy = network(state_size, + hidden_factor, + model_depth, 
+ kernel_size, + squeeze_heads, + decoder_depth, + tail=tail, + softmax=softmax, + debug=False).to(device) + if debug: + print(self.policy) + + parameters = sum(np.prod(p.size()) for p in filter(lambda p: p.requires_grad, self.policy.parameters())) + digits = int(math.log10(parameters)) + number_string = " kMGTPEZY"[digits // 3] + + print(f"[DEBUG/MODEL] Training with {parameters * 10 ** -(digits // 3 * 3):.1f}" + f"{number_string} parameters") + self.policy.apply(init) + try: + self.policy = torch.jit.script(self.policy) + self.old_policy = torch.jit.script(self.old_policy) + except: + import traceback + traceback.print_exc() + print("NO JIT") + self.old_policy.load_state_dict(self.policy.state_dict()) + self.optimizer = Optimizer(self.policy.parameters(), lr=LR, weight_decay=1e-2) + + # Replay memory + self.stack = [[] for _ in range(6)] + self.t_step = 0 + self.idx = 1 + self.tensor_stack = [] + self._policy_update = loss_type in ("PPO",) + self._soft_update = loss_type in ("DQN", "NAF") + + self._action_index = torch.zeros(1) + self._value_index = torch.zeros(1) + 1 + self._triangular_index = torch.zeros(1) + 2 + + self.loss = getattr(self, f'_{loss_type.lower()}_loss') + + def _dqn_loss(self, states, actions, next_states, rewards, done): + actions = actions.argmax(1) + expected = self.policy(self._action_index, self._action_index, *states).gather(1, actions) + best_action = self.policy(self._action_index, self._action_index, next_states).argmax(1) + targets_next = self.old_policy(self._action_index, self._action_index, + *next_states).gather(1, best_action.unsqueeze(1)) + targets = dqn_target(rewards, targets_next, done) + loss = mse(expected, targets) + return loss + + def _naf_loss(self, states, actions, next_states, rewards, done): + targets_next = self.old_policy(self._value_index, self._action_index, next_states) + state_action_values = self.policy(self._triangular_index, actions, states) + targets = dqn_target(rewards, targets_next, done) + loss = mse(state_action_values, targets) + return loss + + def _ppo_loss(self, states, actions, next_states, rewards, done): + _ = next_states + _ = done + actions = actions.argmax(1) + states_clone = [st.clone().detach().requires_grad_(False) for st in states] + old_responsible_outputs = self.old_policy(self._action_index, self._action_index, + *states_clone).gather(1, actions).detach_() + responsible_outputs = self.policy(self._action_index, self._action_index, *states).gather(1, actions) + ratio = responsible_outputs / (old_responsible_outputs + 1e-5) + clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. 
+ CLIP_FACTOR) + loss = aggregate(torch.min(ratio * rewards, clamped_ratio * rewards)).neg() + return loss + + def reset(self): + self.policy.reset_cache() + self.old_policy.reset_cache() + + def multi_act(self, state, argmax_only=True) -> typing.Union[typing.Tuple[np.ndarray, np.ndarray], np.ndarray]: + self.policy.eval() + with torch.no_grad(): + action_values = self.policy(self._action_index, self._action_index, *state).detach() + argmax = action_values.argmax(1).cpu().numpy() + if argmax_only: + return argmax + return action_values, argmax + + def step(self, state, action, agent_done, collision, next_state, step_reward=0, collision_reward=-2): + agent_count = len(agent_done[0]) - 1 + self.stack[0].append(state) + self.stack[1].append(action) + self.stack[2].append([[done[idx] for idx in range(agent_count)] for done in agent_done]) + self.stack[3].append(collision) + self.stack[4].append(next_state) + + if MINI_BACKWARD or len(self.stack[0]) >= UPDATE_EVERY: + action = torch.cat(self.stack[1]).to(device).unsqueeze_(1) + agent_done = np.array(self.stack[2]) + collision = np.array(self.stack[3]) + reward = np.where(agent_done, 1, np.where(collision, collision_reward, step_reward)) + reward = torch.tensor(reward, device=device, dtype=torch.float).flatten(0, 1).unsqueeze_(1) + state = tuple(torch.cat(st, 0) for st in zip(*self.stack[0])) + next_state = tuple(torch.cat(st, 0) for st in zip(*self.stack[4])) + agent_done = torch.as_tensor(agent_done, device=device, dtype=torch.int8).flatten(0, 1).unsqueeze_(1) + self.stack = [[] for _ in range(6)] + self.tensor_stack.append((state, action, reward, next_state, agent_done)) + if len(self.tensor_stack) >= EPOCHS: + tensor_stack = (torch.cat(t, 0) if isinstance(t[0], torch.Tensor) + else tuple(torch.cat(sub_t, 0) for sub_t in zip(*t)) + for t in zip(*self.tensor_stack)) + del self.tensor_stack[0] + self.learn(*tensor_stack) + + def learn(self, states, actions, rewards, next_states, done): + if MINI_BACKWARD: + self.idx = (self.idx + 1) % UPDATE_EVERY + + self.policy.train() + + loss = self.loss(states, actions, next_states, rewards, done) + + if MINI_BACKWARD: + loss = loss / UPDATE_EVERY + + loss.backward() + + if not MINI_BACKWARD or self.idx == 0: + if self._policy_update: + self.old_policy.load_state_dict(self.policy.state_dict()) + self.optimizer.step() + self.optimizer.zero_grad() + if self._soft_update: + for target_param, local_param in zip(self.old_policy.parameters(), self.policy.parameters()): + target_param.data.copy_(DQN_TAU * local_param.data + + (1.0 - DQN_TAU) * target_param.data) + + def save(self, path, *data): + torch.save(self.policy.state_dict(), path / 'dqn/model_checkpoint.local') + torch.save(self.old_policy.state_dict(), path / 'dqn/model_checkpoint.target') + torch.save(self.optimizer.state_dict(), path / 'dqn/model_checkpoint.optimizer') + with open(path / 'dqn/model_checkpoint.meta', 'wb') as file: + pickle.dump(data, file) + + def load(self, path, *defaults): + loc = {} if torch.cuda.is_available() else {'map_location': torch.device('cpu')} + try: + print("Loading model from checkpoint...") + dqn = os.path.join(path, 'dqn') + self.policy.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.local'), **loc)) + self.old_policy.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.target'), **loc)) + self.optimizer.load_state_dict(torch.load(os.path.join(dqn, 'model_checkpoint.optimizer'), **loc)) + with open(os.path.join(dqn, 'model_checkpoint.meta'), 'rb') as file: + return pickle.load(file) + except 
Exception as exc: + import traceback + traceback.print_exc() + print(f"Got exception {exc} loading model data. Possibly no checkpoint found.") + return defaults diff --git a/src/cat.py b/src/cat.py new file mode 100644 index 0000000..a36b8e2 --- /dev/null +++ b/src/cat.py @@ -0,0 +1,50 @@ +import os +import pickle +import random + +import tqdm + + +def main(bucket0, bucket1): + files = sorted([i for i in os.listdir() if i.endswith('pkl') and 'sum' not in i]) + print(f'Concatenating {", ".join(files)}') + + buckets = [{}, {}] + for fname in tqdm.tqdm(files, ncols=120, leave=False): + with open(fname, 'rb') as f: + try: + dat = pickle.load(f) + except Exception as e: + print(f'Caught {e} while processing {fname}') + dat = list(dat) + try: + _, _ = dat[0] + except: + buckets[1][fname.split('_')[-1].split('.pkl')[0]] = dat[:] + else: + buckets[0][fname.split('_')[-1].split('.pkl')[0]] = dat[:] + + def _get_itm(idx): + items = sorted(list(buckets[idx].items())) + random.seed(0) + random.shuffle(items) + names, items = list(zip(*items)) + print(f"First, Last in sequence: {names[0]}, {names[-1]}") + print(f"Random number _after_ shuffling (to check for seed consitency): {random.random()}") + buckets[idx] = [itm for lst in items for itm in lst] + random.shuffle(buckets[idx]) + + def _dump(idx, dump_name: str): + dump_name += 'sum.pkl' + with open(dump_name, 'wb') as f: + pickle.dump(buckets[idx], f) + print(f'Dumped {len(buckets[idx])} items from {len(files)} sources to {dump_name}') + + _get_itm(0) + _get_itm(1) + _dump(0, bucket0) + _dump(1, bucket1) + + +if __name__ == '__main__': + main('rail_networks_', 'schedules_') diff --git a/src/cythonize.sh b/src/cythonize.sh new file mode 100644 index 0000000..f44bc17 --- /dev/null +++ b/src/cythonize.sh @@ -0,0 +1,13 @@ +function compile { + file=${1} + cython "$file.pyx" -3 -Wextra -D + flags="$file.c `python3-config --cflags --ldflags --includes --libs` -I`python -c 'import numpy, sys; sys.stdout.write(numpy.get_include()); sys.stdout.flush()'` -fno-lto -pthread -fPIC -fwrapv -pipe -march=native -mtune=native -Ofast -msse2 -msse4.2 -shared -o $file.so" + echo "Executing gcc with $flags" + (gcc-9 $flags) || (gcc-7 $flags) + echo "Testing compilation.." 
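+  # import the freshly built extension to verify that $file.so can be loaded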
+ python3 -c "import $file" + echo +} + +compile observation_utils +compile rail_env diff --git a/src/dqn/__init__.py b/src/dqn/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/dqn/agent.py b/src/dqn/agent.py deleted file mode 100644 index 6e1e3f9..0000000 --- a/src/dqn/agent.py +++ /dev/null @@ -1,117 +0,0 @@ -import copy -import random -import pickle -import torch -import torch.nn.functional as F - -from dqn.model import QNetwork -from replay_memory import ReplayBuffer - -BUFFER_SIZE = 500_000 -BATCH_SIZE = 512 -GAMMA = 0.998 -TAU = 1e-3 -LR = 0.5e-4 -UPDATE_EVERY = 40 - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - -class Agent: - def __init__(self, state_size, action_size, num_agents, double_dqn=True): - self.action_size = action_size - self.double_dqn = double_dqn - - # Q-Network - self.qnetwork_local = QNetwork(state_size, action_size).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) - self.optimizer = torch.optim.Adam(self.qnetwork_local.parameters(), lr=LR) - - # Replay memory - self.memory = ReplayBuffer(BUFFER_SIZE) - self.num_agents = num_agents - self.t_step = 0 - - def reset(self): - self.finished = [False] * self.num_agents - - - # Decide on an action to take in the environment - - def act(self, state, eps=0.): - state = torch.from_numpy(state).float().unsqueeze(0).to(device) - self.qnetwork_local.eval() - with torch.no_grad(): - action_values = self.qnetwork_local(state) - - # Epsilon-greedy action selection - if random.random() > eps: - return torch.argmax(action_values).item() - else: return torch.randint(self.action_size, ()).item() - - - # Record the results of the agent's action and update the model - - def step(self, handle, state, action, next_state, agent_done, episode_done, collision): - if not self.finished[handle]: - if agent_done: - reward = 1 - elif collision: - reward = -5 - else: reward = -.1 - - # Save experience in replay memory - self.memory.push(state, action, reward, next_state, agent_done or episode_done) - self.finished[handle] = agent_done or episode_done - - # Perform a gradient update every UPDATE_EVERY time steps - self.t_step = (self.t_step + 1) % UPDATE_EVERY - if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 20: - self.learn(*self.memory.sample(BATCH_SIZE, device)) - - - def learn(self, states, actions, rewards, next_states, dones): - self.qnetwork_local.train() - - # Get expected Q values from local model - Q_expected = self.qnetwork_local(states).gather(1, actions) - - if self.double_dqn: - Q_best_action = self.qnetwork_local(next_states).argmax(1) - Q_targets_next = self.qnetwork_target(next_states).gather(1, Q_best_action.unsqueeze(-1)) - else: Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) - - # Compute Q targets for current states - Q_targets = rewards + GAMMA * Q_targets_next * (1 - dones) - - # Compute loss and perform a gradient step - self.optimizer.zero_grad() - loss = F.mse_loss(Q_expected, Q_targets) - loss.backward() - self.optimizer.step() - - # Update the target network parameters to `tau * local.parameters() + (1 - tau) * target.parameters()` - for target_param, local_param in zip(self.qnetwork_target.parameters(), self.qnetwork_local.parameters()): - target_param.data.copy_(TAU * local_param.data + (1.0 - TAU) * target_param.data) - - - # Checkpointing methods - - def save(self, path, *data): - torch.save(self.qnetwork_local.state_dict(), path / 'dqn/model_checkpoint.local') - 
torch.save(self.qnetwork_target.state_dict(), path / 'dqn/model_checkpoint.target') - torch.save(self.optimizer.state_dict(), path / 'dqn/model_checkpoint.optimizer') - with open(path / 'dqn/model_checkpoint.meta', 'wb') as file: - pickle.dump(data, file) - - def load(self, path, *defaults): - try: - print("Loading model from checkpoint...") - self.qnetwork_local.load_state_dict(torch.load(path / 'dqn/model_checkpoint.local')) - self.qnetwork_target.load_state_dict(torch.load(path / 'dqn/model_checkpoint.target')) - self.optimizer.load_state_dict(torch.load(path / 'dqn/model_checkpoint.optimizer')) - with open(path / 'dqn/model_checkpoint.meta', 'rb') as file: - return pickle.load(file) - except: - print("No checkpoint file was found") - return defaults diff --git a/src/dqn/model.py b/src/dqn/model.py deleted file mode 100644 index f64d641..0000000 --- a/src/dqn/model.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F - - -class QNetwork(nn.Module): - def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128): - super(QNetwork, self).__init__() - - self.fc1_val = nn.Linear(state_size, hidsize1) - self.fc2_val = nn.Linear(hidsize1, hidsize2) - self.fc3_val = nn.Linear(hidsize2, 1) - - self.fc1_adv = nn.Linear(state_size, hidsize1) - self.fc2_adv = nn.Linear(hidsize1, hidsize2) - self.fc3_adv = nn.Linear(hidsize2, action_size) - - def forward(self, x): - x = x.view(x.shape[0], -1) - - val = F.relu(self.fc1_val(x)) - val = F.relu(self.fc2_val(val)) - val = self.fc3_val(val) - - adv = F.relu(self.fc1_adv(x)) - adv = F.relu(self.fc2_adv(adv)) - adv = self.fc3_adv(adv) - - return val + adv - adv.mean() diff --git a/src/generate_railways.py b/src/generate_railways.py index 0f0a6ba..0034d91 100755 --- a/src/generate_railways.py +++ b/src/generate_railways.py @@ -1,45 +1,55 @@ -import time +import argparse +import multiprocessing import pickle -import numpy as np -from tqdm import tqdm from pathlib import Path -from flatland.envs.rail_generators import sparse_rail_generator, complex_rail_generator -from flatland.envs.schedule_generators import sparse_schedule_generator, complex_schedule_generator +import numpy as np +from tqdm import tqdm -from railway_utils import create_random_railways +try: + from .railway_utils import create_random_railways +except: + from railway_utils import create_random_railways project_root = Path(__file__).resolve().parent.parent -rail_generator, schedule_generator = create_random_railways(project_root) +parser = argparse.ArgumentParser(description="Train an agent in the flatland environment") -width, height = 50, 50 -n_agents = 5 +parser.add_argument("--width", type=int, default=35, help="Decay factor for epsilon-greedy exploration") +parser.add_argument("--factor", type=int, default=2, help="Decay factor for epsilon-greedy exploration") +parser.add_argument("--base", type=float, default=1.1, help="Decay factor for epsilon-greedy exploration") +flags = parser.parse_args() +rail_generator, schedule_generator = create_random_railways(flags.width, flags.base, flags.factor) # Load in any existing railways for this map size so we don't overwrite them +network = project_root / f'railroads/rail_networks_{flags.width}_{flags.factor}.pkl' +sched = project_root / f'railroads/schedules_{flags.width}_{flags.factor}.pkl' try: - with open(project_root / f'railroads/rail_networks_{n_agents}x{width}x{height}.pkl', 'rb') as file: + with open(network, 'rb') as file: rail_networks = pickle.load(file) - with open(project_root / 
f'railroads/schedules_{n_agents}x{width}x{height}.pkl', 'rb') as file: + with open(sched, 'rb') as file: schedules = pickle.load(file) print(f"Loading {len(rail_networks)} railways...") except: rail_networks, schedules = [], [] -# Generate 10000 random railways in 100 batches of 100 -for _ in range(100): - for i in tqdm(range(100), ncols=120, leave=False): - map, info = rail_generator(width, height, n_agents, num_resets=0, np_random=np.random) - schedule = schedule_generator(map, n_agents, info['agents_hints'], num_resets=0, np_random=np.random) +def do(schedules: list, rail_networks: list): + for _ in range(100): + map, info = rail_generator(flags.width, 1, 1, num_resets=0, np_random=np.random) + schedule = schedule_generator(map, 1, info['agents_hints'], num_resets=0, np_random=np.random) rail_networks.append((map, info)) schedules.append(schedule) + return + - print(f"Saving {len(rail_networks)} railways") - with open(project_root / f'railroads/rail_networks_{n_agents}x{width}x{height}.pkl', 'wb') as file: - pickle.dump(rail_networks, file) - with open(project_root / f'railroads/schedules_{n_agents}x{width}x{height}.pkl', 'wb') as file: - pickle.dump(schedules, file) +for _ in tqdm(range(500), ncols=150, leave=False): + do(schedules, rail_networks) + with open(network, 'wb') as file: + pickle.dump(rail_networks, file, protocol=4) + with open(sched, 'wb') as file: + pickle.dump(schedules, file, protocol=4) +print(f"Saved {len(rail_networks)} railways") print("Done") diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000..dfb0714 --- /dev/null +++ b/src/model.py @@ -0,0 +1,299 @@ +import typing + +import torch + + +@torch.jit.script +def mish(fn_input: torch.Tensor) -> torch.Tensor: + return fn_input * torch.tanh(torch.nn.functional.softplus(fn_input)) + + +@torch.jit.script +def nothing(x): + return x + + +class Mish(torch.nn.Module): + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + return mish(fn_input) + + +class SeparableConvolution(torch.nn.Module): + def __init__(self, in_features, out_features, kernel_size: typing.Union[int, tuple], + padding: typing.Union[int, tuple] = 0, dilation: typing.Union[int, tuple] = 1, + bias=False, dim=1, stride: typing.Union[int, tuple] = 1, dropout=0.25): + super(SeparableConvolution, self).__init__() + self.depthwise = kernel_size > 1 if isinstance(kernel_size, int) else any(k > 1 for k in kernel_size) + conv = getattr(torch.nn, f'Conv{dim}d') + norm = getattr(torch.nn, f'InstanceNorm{dim}d') + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * dim + if self.depthwise: + self.depthwise_conv = conv(in_features, in_features, + kernel_size, + padding=padding, + groups=in_features, + dilation=dilation, + bias=False, + stride=stride) + self.mid_norm = norm(in_features, affine=True) + else: + self.depthwise_conv = nothing + self.mid_norm = nothing + self.pointwise_conv = conv(in_features, out_features, (1,) * dim, bias=bias) + self.str = (f'SeparableConvolution({in_features}, {out_features}, {kernel_size}, ' + + f'dilation={dilation}, padding={padding}, stride={stride})') + self.dropout = dropout * (in_features == out_features and (stride == 1 or all(s == 1 for s in stride))) + + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + if torch.rand(1) < self.dropout: + return fn_input + if self.depthwise: + fn_input = self.mid_norm(self.depthwise_conv(fn_input)) + return self.pointwise_conv(fn_input) + + def __str__(self): + return self.str + + def __repr__(self): + return self.str + + +class 
BasicBlock(torch.nn.Module): + def __init__(self, in_features, out_features, stride, init_norm=False, message_box=None, double=True, + agent_dim=True): + super(BasicBlock, self).__init__() + self.activate = init_norm + self.double = double + self.agent_dim = agent_dim + dim = 2 + agent_dim + norm = getattr(torch.nn, f'InstanceNorm{dim}d') + kernel = (3, 3) + ((1,) if agent_dim else ()) + pad = (1, 1) + ((0,) if agent_dim else ()) + stride = (stride, stride) + ((1,) if agent_dim else ()) + self.init_norm = norm(in_features, affine=True) if init_norm else nothing + self.init_conv = SeparableConvolution(in_features, out_features, kernel, pad, stride=stride, dim=dim) + if double: + self.mid_norm = norm(out_features, affine=True) + self.end_conv = SeparableConvolution(out_features, out_features, kernel, pad, dim=dim) + else: + self.mid_norm = nothing + self.end_conv = nothing + self.shortcut = torch.nn.Sequential() + if stride[0] > 1: + self.shortcut.add_module("1", getattr(torch.nn, f"MaxPool{dim}d")(kernel, stride, padding=pad)) + if in_features != out_features: + self.shortcut.add_module("2", getattr(torch.nn, f"Conv{dim}d")(in_features, out_features, 1)) + self.message_box = int(out_features ** 0.5) if message_box is None else message_box + + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + out = self.init_conv(fn_input if self.activate is None else mish(self.init_norm(fn_input))) + if self.double: + out = self.end_conv(mish(self.mid_norm(out))) + if self.agent_dim and self.message_box > 0: + out[:, :self.message_box] = out[:, :self.message_box].mean(-1, + keepdim=True).expand(-1, -1, -1, -1, + fn_input.size(-1)) + srt = self.shortcut(fn_input) + out = out + srt + return out + + +class NAFHead(torch.nn.Module): + def __init__(self, hidden_size, action_size): + super(NAFHead, self).__init__() + self.action_size = action_size + self.action = torch.nn.Conv1d(hidden_size, action_size, 1) + self.value = torch.nn.Conv1d(hidden_size, 1, 1) + self.triangular = torch.nn.Conv1d(hidden_size, action_size ** 2, 1) + self.diagonal_mask = torch.ones((action_size, action_size)).diag().diag().unsqueeze_(0) + + def forward(self, fn_input, idx, prev_action): + if idx == 0: + return self.action(fn_input) + if idx == 1: + return self.value(fn_input) + + actions, value = self.action(fn_input), self.value(fn_input) + + batch = fn_input.size(0) + triangular = self.triangular(fn_input).view(batch, self.action_size, self.action_size, -1) + triangular = triangular.tril() + triangular.mul(self.diagonal_mask).exp() + matrix = torch.bmm(triangular, triangular.transpose(1, 2)) + + action_difference = (prev_action - actions).unsqueeze(2) + advantage = torch.bmm(torch.bmm(action_difference.transpose(1, 2), matrix), + action_difference)[:, :, 0].div(2).neg() + + return advantage + value + + +class TripleClassificationHead(torch.nn.Module): + def __init__(self, hidden_size, action_size): + super(TripleClassificationHead, self).__init__() + self.linear = torch.nn.Conv1d(hidden_size, action_size, 1) + + def forward(self, fn_input, idx, prev_action): + _ = idx + _ = prev_action + return self.linear(fn_input) + + +class TailModel(torch.nn.Module): + def __init__(self, tail): + super(TailModel, self).__init__() + self.tail = nothing if tail is None else tail + self.no_tensor = torch.zeros(1) + + def _backbone(self, fn_input: torch.Tensor, potential_fn_input: torch.Tensor) -> torch.Tensor: + raise UserWarning("Has to be implemented by child class") + + def forward(self, idx, prev_action, state, rail): + out = 
self._backbone(state, rail) + return self.tail(out, idx, prev_action) + + @torch.jit.export + def reset_cache(self): + pass + + +class ConvNetwork(TailModel): + def __init__(self, state_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, + decoder_depth=1, cat=True, debug=True, softmax=False, tail=None): + super(ConvNetwork, self).__init__(tail) + _ = state_size + state_size = 2 * 21 + self.net = torch.nn.ModuleList([BasicBlock(state_size if not i else hidden_size, + hidden_size, + 2, + init_norm=bool(i)) + for i in range(depth)]) + self.init_norm = torch.nn.InstanceNorm1d(hidden_size, affine=True) + + def _backbone(self, fn_input: torch.Tensor, trash: torch.Tensor) -> torch.Tensor: + _ = trash + out = fn_input + for module in self.net: + out = module(out) + out = out.mean((2, 3)) + out = mish(self.init_norm(out)) + return out + + +class DecoderBlock(torch.nn.Module): + def __init__(self, in_features, out_features, message_box=None, init_norm=True): + super(DecoderBlock, self).__init__() + self.init_norm = init_norm + self.residual = in_features == out_features + self.norm = torch.nn.InstanceNorm1d(in_features, affine=True) if init_norm else nothing + self.conv = torch.nn.Conv1d(in_features, out_features, 1, bias=False) + self.message_box = int(out_features ** 0.5) if message_box is None else message_box + + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + if self.init_norm: + out = self.norm(fn_input) + out = mish(out) + else: + out = fn_input + out = self.conv(out) + if self.message_box > 0: + out[:, :self.message_box] = out[:, :self.message_box].mean(-1, keepdim=True).expand(-1, -1, + fn_input.size(-1)) + if self.residual: + return out + fn_input + return out + + +@torch.jit.script +def attention(tensor: torch.Tensor): + query, key, value = tensor.chunk(3, 1) + query = query.transpose(1, 2) # B, F, S -> B, S, F + key = torch.bmm(query, key).softmax(1) + value = torch.bmm(value, key) + return value + + +class GlobalStateNetwork(TailModel): + def __init__(self, state_size, hidden_size=15, depth=8, kernel_size=7, squeeze_heads=4, + decoder_depth=1, cat=True, debug=True, softmax=False, memory_size=4, tail=None): + super(GlobalStateNetwork, self).__init__(tail) + _ = state_size + _ = kernel_size + _ = squeeze_heads + _ = cat + _ = debug + global_state_size = 2 * 16 + agent_state_size = 2 * 13 + + self.memory_size = memory_size + + self.net = torch.nn.Sequential(*[BasicBlock(global_state_size if not i else hidden_size, + hidden_size + (i == depth - 1) * hidden_size * 2, + 2, + init_norm=bool(i), + message_box=0, + double=False, + agent_dim=False) + for i in range(depth)]) + self.decoder_input = DecoderBlock(agent_state_size, 3 * hidden_size, 0, False) + self.decoder = torch.nn.ModuleList([DecoderBlock(hidden_size, 3 * hidden_size, 0, True) + for _ in range(1, decoder_depth)]) + self.end_norm = torch.nn.InstanceNorm1d(hidden_size, affine=True) + + self.memory_tensor = torch.nn.Parameter(torch.randn(1, 3 * hidden_size, memory_size)) + + def _backbone(self, state: torch.Tensor, rail: torch.Tensor) -> torch.Tensor: + inp = self.net(rail) + inp = inp.mean((2, 3), keepdim=True).squeeze(-1) + state = torch.cat([self.decoder_input(state) + inp, self.memory_tensor.expand(inp.size(0), -1, -1)], 2) + state = attention(state) + for block in self.decoder: + state = attention(block(state)) + state = state[:, :, :-self.memory_size] + out = self.end_norm(state) + return out + + +def init(module: torch.nn.Module): + if hasattr(module, "weight") and hasattr(module.weight, "data"): + if 
"norm" in module.__class__.__name__.lower() or ( + hasattr(module, "__str__") and "norm" in str(module).lower()): + torch.nn.init.uniform_(module.weight.data, 0.998, 1.002) + else: + torch.nn.init.orthogonal_(module.weight.data) + if hasattr(module, "bias") and hasattr(module.bias, "data"): + torch.nn.init.constant_(module.bias.data, 0) + + +class Residual(torch.nn.Module): + def __init__(self, features): + super(Residual, self).__init__() + self.norm = torch.nn.BatchNorm1d(features) + self.conv = torch.nn.Conv1d(features, 2 * features, 1) + + def forward(self, fn_input: torch.Tensor) -> torch.Tensor: + out, exc = self.conv(mish(self.norm(fn_input))).chunk(2, 1) + exc = exc.mean(dim=-1, keepdim=True).tanh() + fn_input = fn_input * exc + out = -out * exc + return fn_input + out + + +class QNetwork(TailModel): + def __init__(self, state_size, hidden_factor=16, depth=4, kernel_size=7, squeeze_heads=4, + decoder_depth=1, cat=False, debug=True, tail=None): + super(QNetwork, self).__init__(tail) + _ = depth + _ = kernel_size + _ = squeeze_heads + _ = cat + _ = debug + _ = decoder_depth + self.model = torch.nn.Sequential(torch.nn.Conv1d(2 * state_size, 11 * hidden_factor, 1, groups=11, bias=False), + Residual(11 * hidden_factor), + torch.nn.BatchNorm1d(11 * hidden_factor), + Mish()) + + def _backbone(self, fn_input: torch.Tensor, trash: torch.Tensor) -> torch.Tensor: + _ = trash + return self.model(fn_input) diff --git a/src/normalize_output_data.py b/src/normalize_output_data.py new file mode 100644 index 0000000..aa1de4a --- /dev/null +++ b/src/normalize_output_data.py @@ -0,0 +1,35 @@ +import torch + + +torch.jit.optimized_execution(True) + + +def wrap(data: torch.Tensor): + start = data[:, :, :6] + mid = data[:, :, 6] + + max0 = torch.where(start < 1000, start, torch.zeros_like(start)) + max0 = max0.max(dim=1, keepdim=True)[0] + max0.clamp_(min=1) + + max1 = torch.where(mid < 1000, mid, torch.zeros_like(mid)) + max1 = max1.max(dim=1, keepdim=True)[0] + max1.clamp_(min=1) + + min_mid = torch.where(mid >= 0, mid, torch.zeros_like(mid)) + min_obs = min_mid.min(dim=1, keepdim=True)[0] + + mid.sub_(min_obs) + max1.sub_(min_obs) + mid.div_(max1) + + start.div_(max0) + + data.clamp_(-1, 1) + + data[:, :, :6].sub_(data[:, :, :6].mean()) + data[:, :, 7:].sub_(data[:, :, 7:].mean()) + data.detach_() + + +wrap = torch.jit.script(wrap) diff --git a/src/observation_utils.py b/src/observation_utils.py deleted file mode 100644 index fc7e18c..0000000 --- a/src/observation_utils.py +++ /dev/null @@ -1,57 +0,0 @@ -import numpy as np -from tree_observation import ACTIONS - -ZERO_NODE = np.array([0] * 11) # For Q-Networks -INF_DISTANCE_NODE = np.array([0] * 6 + [np.inf] + [0] * 4) # For policy networks - - -# Helper function to detect collisions -def is_collision(obs): - return obs is not None \ - and isinstance(obs.childs['L'], float) \ - and isinstance(obs.childs['R'], float) \ - and obs.childs['F'].num_agents_opposite_direction > 0 \ - and obs.childs['F'].dist_other_agent_encountered <= 1 \ - and obs.childs['F'].dist_other_agent_encountered < obs.childs['F'].dist_unusable_switch - # and obs.childs['F'].dist_other_agent_encountered < obs.childs['F'].dist_to_next_branch - - -# Recursively create numpy arrays for each tree node -def create_tree_features(node, current_depth, max_depth, empty_node, data): - if node == -np.inf: - num_remaining_nodes = (4 ** (max_depth - current_depth + 1) - 1) // (4 - 1) - data.extend([empty_node] * num_remaining_nodes) - - else: - data.append(np.array(tuple(node)[:-2])) - if 
node.childs: - for direction in ACTIONS: - create_tree_features(node.childs[direction], current_depth + 1, max_depth, empty_node, data) - - return data - -# Normalize an observation to [0, 1] and then clip it to get rid of any infinite-valued features -def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False): - if fixed_radius > 0: - max_obs = fixed_radius - else: max_obs = np.max(obs[np.where(obs < 1000)], initial=1) + 1 - - min_obs = np.min(obs[np.where(obs >= 0)], initial=max_obs) if normalize_to_range else 0 - - if max_obs == min_obs: - return np.clip(obs / max_obs, clip_min, clip_max) - else: return np.clip((obs - min_obs) / np.abs(max_obs - min_obs), clip_min, clip_max) - - -# Normalize a tree observation -def normalize_observation(tree, max_depth, zero_center=True): - empty_node = ZERO_NODE if zero_center else INF_DISTANCE_NODE - data = np.concatenate(create_tree_features(tree, 0, max_depth, empty_node, [])).reshape((-1, 11)) - - obs_data = norm_obs_clip(data[:,:6].flatten()) - distances = norm_obs_clip(data[:,6], normalize_to_range=True) - agent_data = np.clip(data[:,7:].flatten(), -1, 1) - - if zero_center: - return np.concatenate((obs_data - obs_data.mean(), distances, agent_data - agent_data.mean())) - else: return np.concatenate((obs_data, distances, agent_data)) diff --git a/src/observation_utils.pyx b/src/observation_utils.pyx new file mode 100644 index 0000000..e26249e --- /dev/null +++ b/src/observation_utils.pyx @@ -0,0 +1,590 @@ +#!python +#cython: boundscheck=False +#cython: initializedcheck=False +#cython: nonecheck=False +#cython: wraparound=False +#cython: cdivision=True + +from collections import defaultdict + +cimport numpy as cnp +import numpy as np +import torch +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.observations import TreeObsForRailEnv + +cdef list ACTIONS = ['L', 'F', 'R', 'B'] +Node = TreeObsForRailEnv.Node + +cdef int positive_infinity = int(1e5) +cdef int negative_infinity = -positive_infinity + +def first(iterable): + for elem in iterable: + return elem + +cpdef bint _check_len1(tuple obj): + return len(obj) > 1 + +cpdef str get_action(int orientation, int direction): + return ACTIONS[(direction - orientation + 1 + 4) % 4] + +cpdef int get_direction(int orientation, int action): + if action == 1: + return (orientation - 1) % 4 + elif action == 3: + return (orientation + 1) % 4 + else: + return orientation + +cpdef set_array(tuple data, int width, cnp.ndarray arr, int base_index, int start_index): + arr[start_index:start_index + 4, base_index] = data + (data[0] / width, data[1] / width) + +cdef class RailNode: + cdef public dict edges + cdef public tuple position + cdef public tuple edge_directions + cdef public int is_target + def __init__(self, tuple position, tuple edge_directions, int is_target): + self.edges = {} + self.position = position + self.edge_directions = edge_directions + self.is_target = is_target + + def __repr__(self): + return f'RailNode({self.position}, {len(self.edges)})' + + +class GlobalObsForRailEnv(ObservationBuilder): + """ + Gives a global observation of the entire rail environment. + The observation is composed of the following elements: + + - transition map array with dimensions (env.height, env.width, 16),\ + assuming 16 bits encoding of transitions. 
+ + - obs_agents_state: A 3D array (map_height, map_width, 5) with + - first channel containing the agents position and direction + - second channel containing the other agents positions and direction + - third channel containing agent/other agent malfunctions + - fourth channel containing agent/other agent fractional speeds + - fifth channel containing number of other agents ready to depart + + - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ + target and the positions of the other agents targets (flag only, no counter!). + """ + def __init__(self, data_type=np.int8, return_array=True): + super(GlobalObsForRailEnv, self).__init__() + self.size = 0 + self._custom_rail_obs = None + self.data_type = data_type + self.return_array = return_array + def reset(self): + if self._custom_rail_obs is None: + self._custom_rail_obs = np.zeros((1, self.env.height + 2 * self.size, self.env.width + 2 * self.size, 16)) + cdef cnp.ndarray out = np.array([[[[1 if digit == '1' else 0 + for digit in + f'{self.env.rail.get_full_transitions(i, j):016b}'] + for j in + range(self.env.width)] + for i in + range(self.env.height)]], + dtype=self.data_type) + if self.size > 0: + self._custom_rail_obs[0, self.size:self.env.height - self.size, self.size:self.env.width - self.size] = out + else: + self._custom_rail_obs = out + + def get_many(self, list trash): + cdef int agent_count = len(self.env.agents) + cdef cnp.ndarray obs_agents_state = np.zeros((agent_count, + self.env.height, + self.env.width, + 5), dtype=self.data_type) + cdef int i, agent_id + cdef tuple pos, agent_virtual_position + for agent_id, agent in enumerate(self.env.agents): + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_virtual_position = agent.target + else: + continue + + obs_agents_state[agent_id, :, :, 0:4] = -1 + + obs_agents_state[(agent_id,) + agent_virtual_position + (0,)] = agent.direction + + for i, other_agent in enumerate(self.env.agents): + + # ignore other agents not in the grid any more + if other_agent.status == RailAgentStatus.DONE_REMOVED: + continue + + # second to fourth channel only if in the grid + if other_agent.position is not None: + pos = (agent_id,) + other_agent.position + # second channel only for other agents + if i != agent_id: + obs_agents_state[pos + (1,)] = other_agent.direction + obs_agents_state[pos + (2,)] = other_agent.malfunction_data['malfunction'] + obs_agents_state[pos + (3,)] = other_agent.speed_data['speed'] + # fifth channel: all ready to depart on this position + if other_agent.status == RailAgentStatus.READY_TO_DEPART: + obs_agents_state[(agent_id,) + other_agent.initial_position + (4,)] += 1 + if self.return_array: + return obs_agents_state + return dict(enumerate(np.concatenate([np.repeat(self.rail_obs, agent_count, 0), obs_agents_state], -1))) + + +class LocalObsForRailEnv(GlobalObsForRailEnv): + def __init__(self, size=7, return_array=True): + super(LocalObsForRailEnv, self).__init__(return_array=return_array) + self.size = size + def get_many(self, list trash): + cdef int agent_count = len(self.env.agents) + obs_agents_state = np.zeros((agent_count, + self.size * 2 + 1, + self.size * 2 + 1, + 21), dtype=np.float32) + cdef int i, agent_id + cdef tuple agent_virtual_position + for agent_id, agent in enumerate(self.env.agents): + if agent.status 
== RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_virtual_position = agent.target + else: + continue + x0, y0, x1, y1 = (agent_virtual_position[0], + agent_virtual_position[1], + agent_virtual_position[0] + 2 * self.size + 1, + agent_virtual_position[1] + 2 * self.size + 1) + obs_agents_state[agent_id, :, :, 5:] = self._custom_rail_obs[0, x0:x1, y0:y1] + + obs_agents_state[agent_id, :, :, 0:4] = -1 + + obs_agents_state[agent_id, :, :, 0] = agent.direction + + for i, other_agent in enumerate(self.env.agents): + + # ignore other agents not in the grid any more + if other_agent.status == RailAgentStatus.DONE_REMOVED: + continue + + # second to fourth channel only if in the grid + if other_agent.position is not None: + if i != agent_id: + obs_agents_state[agent_id, :, :, 1] = other_agent.direction + obs_agents_state[agent_id, :, :, 2] = other_agent.malfunction_data['malfunction'] + obs_agents_state[agent_id, :, :, 3] = other_agent.speed_data['speed'] + # fifth channel: all ready to depart on this position + if other_agent.status == RailAgentStatus.READY_TO_DEPART: + init_pos = other_agent.initial_position + dist0 = agent_virtual_position[0] - init_pos[0] + dist1 = agent_virtual_position[1] - init_pos[1] + if abs(dist0) < self.size and abs(dist1) < self.size: + obs_agents_state[agent_id, dist0 + self.size, dist1 + self.size, 4] += 1 + if self.return_array: + return obs_agents_state + return dict(enumerate(obs_agents_state)) + + +class GlobalStateObs(GlobalObsForRailEnv): + def __init__(self, return_array=True): + super(GlobalStateObs, self).__init__(return_array=return_array) + def get_many(self, list trash): + cdef int agent_count = len(self.env.agents) + cdef cnp.ndarray obs_agents_state = np.zeros((13, agent_count), dtype=np.float32) + cdef int agent_id + cdef tuple agent_virtual_position + + for agent_id, agent in enumerate(self.env.agents): + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_virtual_position = agent.target + else: # Done+Removed + continue + + obs_agents_state[0, agent_id] = agent.direction + obs_agents_state[1, agent_id] = agent.malfunction_data['malfunction'] + obs_agents_state[2, agent_id] = agent.speed_data['speed'] + obs_agents_state[3, agent_id] = agent.status + obs_agents_state[4, agent_id] = agent.moving + set_array(agent_virtual_position, self.env.width, obs_agents_state, agent_id, 5) + set_array(agent.target, self.env.width, obs_agents_state, agent_id, 9) + + if self.return_array: + return obs_agents_state, self._custom_rail_obs[0] + return dict(enumerate(obs_agents_state)) + + +class TreeObservation(ObservationBuilder): + def __init__(self, max_depth): + super().__init__() + self.max_depth = max_depth + self.observation_dim = 11 + + # Create a graph representation of the current rail network + + def reset(self): + self.target_positions = {agent.target: 1 for agent in self.env.agents} + self.edge_positions = defaultdict(list) # (cell.position, direction) -> [(start, end, direction, distance)] + self.edge_paths = defaultdict(list) # (node.position, direction) -> [(cell.position, direction)] + + # First, we find a node by starting at one of the agents and following the rails until we reach a junction 
+ agent = first(self.env.agents) + cpdef tuple position = tuple(agent.initial_position) + cpdef int direction = agent.direction + cdef bint out + while True: + try: + out = self.is_junction(position) or self.is_target(position) + except IndexError: + break + if not out: + break + direction = first(self.get_possible_transitions(position, direction)) + position = get_new_position(position, direction) + + # Now we create a graph representation of the rail network, starting from this node + cdef tuple transitions = self.get_all_transitions(position) + cdef dict root_nodes = {t: RailNode(position, t, self.is_target(position)) for t in transitions if t} + self.graph = {(*position, d): root_nodes[t] for d, t in enumerate(transitions) if t} + + for transitions, node in root_nodes.items(): + for direction in transitions: + self.explore_branch(node, get_new_position(position, direction), direction) + + def explore_branch(self, RailNode node, tuple position, int direction): + cdef int original_direction = direction + cdef dict edge_positions = {} + cdef int distance = 1 + cdef int next_direction = 0 + cdef int idx = 0 + cdef tuple transition = tuple() + cdef tuple key = tuple() + + # Explore until we find a junction + while not self.is_junction(position) and not self.is_target(position): + next_direction = first(self.get_possible_transitions(position, direction)) + edge_positions[(*position, direction)] = (distance, next_direction) + position = get_new_position(position, next_direction) + direction = next_direction + distance += 1 + + # Create any nodes that aren't in the graph yet + cdef tuple transitions = self.get_all_transitions(position) + cdef bint is_target = self.is_target(position) + cdef dict nodes = {transition: RailNode(position, transition, is_target) + for idx, transition in enumerate(transitions) + if transition and (*position, idx) not in self.graph} + + for idx, transition in enumerate(transitions): + if transition in nodes: + self.graph[(*position, idx)] = nodes[transition] + + # Connect the previous node to the next one, and update self.edge_positions + cdef RailNode next_node = self.graph[(*position, direction)] + node.edges[original_direction] = (next_node, distance) + for key, (distance, next_direction) in edge_positions.items(): + self.edge_positions[key].append((node, next_node, original_direction, distance)) + self.edge_paths[node.position, original_direction].append((*key, next_direction)) + + # Call ourselves recursively since we're exploring depth-first + for transitions, node in nodes.items(): + for direction in transitions: + self.explore_branch(node, get_new_position(position, direction), direction) + + # Create a tree observation for each agent, based on the graph we created earlier + + def get_many(self, list handles=[]): + self.nodes_with_agents_going, self.edges_with_agents_going = {}, defaultdict(dict) + self.nodes_with_agents_coming, self.edges_with_agents_coming = {}, defaultdict(dict) + self.nodes_with_malfunctions, self.edges_with_malfunctions = {}, defaultdict(dict) + self.nodes_with_departures, self.edges_with_departures = {}, defaultdict(dict) + + cdef int direction = 0 + + # Create some lookup tables that we can use later to figure out how far away the agents are from each other. 
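+        # Node tables are keyed by the (row, column, direction) of a graph node (junction or target cell)
+        # and store a flag, speed or malfunction duration; edge tables are keyed by the (row, column,
+        # direction) at which an edge starts and map each agent handle to its distance along that edge
+        # (plus the agent's speed or malfunction duration where applicable).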
+ for agent in self.env.agents: + if agent.status == RailAgentStatus.READY_TO_DEPART and agent.initial_position: + for direction in range(4): + if (*agent.initial_position, direction) in self.graph: + self.nodes_with_departures[(*agent.initial_position, direction)] = 1 + + for start, _, start_direction, distance in self.edge_positions[ + (*agent.initial_position, direction)]: + self.edges_with_departures[(*start.position, start_direction)][agent.handle] = distance + + if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and agent.position: + agent_key = (*agent.position, agent.direction) + for direction in range(4): + # # Check the nodes + if (*agent.position, direction) in self.graph: + node_dict = self.nodes_with_agents_going if direction == agent.direction else self.nodes_with_agents_coming + node_dict[(*agent.position, direction)] = agent.speed_data['speed'] + + # if len(self.graph[agent_key].edges) > 1: + # exit_direction = get_direction(agent.direction, agent.speed_data['transition_action_on_cellexit']) + # if agent.speed_data['position_fraction'] == 0 or exit_direction not in self.graph[agent_key].edges: # Agent still has options + # self.nodes_with_agents_going[(*agent.position, direction)] = agent.speed_data['speed'] + # else: # Agent has already decided + # coming_direction = (exit_direction + 2) % 4 + # node_dict = self.nodes_with_agents_coming if direction == coming_direction else self.nodes_with_agents_going + # node_dict[(*agent.position, direction)] = agent.speed_data['speed'] + # else: + # exit_direction = first(self.graph[agent_key].edges.keys()) + # coming_direction = (exit_direction + 2) % 4 + # node_dict = self.nodes_with_agents_coming if direction == coming_direction else self.nodes_with_agents_going + # node_dict[(*agent.position, direction)] = agent.speed_data['speed'] + + # Check the edges + if agent_key in self.edge_positions: + exit_direction = first(self.get_possible_transitions(agent.position, agent.direction)) + coming_direction = (exit_direction + 2) % 4 + edge_dict = self.edges_with_agents_coming if direction == coming_direction else self.edges_with_agents_going + if direction == agent.direction or direction == coming_direction: + for start, _, start_direction, distance in self.edge_positions[ + (*agent.position, direction)]: + edge_dict[(*start.position, start_direction)][agent.handle] = ( + distance, agent.speed_data['speed']) + + # Check for malfunctions + if agent.malfunction_data['malfunction']: + if (*agent.position, direction) in self.graph: + self.nodes_with_malfunctions[(*agent.position, direction)] = agent.malfunction_data[ + 'malfunction'] + + for start, _, start_direction, distance in self.edge_positions[(*agent.position, direction)]: + self.edges_with_malfunctions[(*start.position, start_direction)][agent.handle] = \ + (distance, agent.malfunction_data['malfunction']) + cdef dict data = super().get_many(handles) + return data + #return tuple(data.values()) + + # Compute the observation for a single agent + def get(self, int handle): + agent = self.env.agents[handle] + cdef set visited_cells = set() + + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_position = agent.target + else: + return None + + # The root node contains information about the agent itself + cdef int direction = 0 + cdef int distance = 0 + root_tree_node = Node(0, 0, 0, 0, 0, 0, 
self.env.distance_map.get()[(handle, *agent_position, agent.direction)], + 0, 0, agent.malfunction_data['malfunction'], agent.speed_data['speed'], 0, + {x: negative_infinity for x in ACTIONS}) + + # Then we build out the tree by exploring from this node + cdef tuple key = (*agent_position, agent.direction) + cdef RailNode node = RailNode(tuple(), tuple(), 0) + cdef RailNode prev_node = RailNode(tuple(), tuple(), 0) + if key in self.graph: # If we're sitting on a junction, branch out immediately + node = self.graph[key] + if len(node.edges) > 1: # Major node + for direction in self.graph[key].edges.keys(): + root_tree_node.childs[get_action(agent.direction, direction)] = \ + self.get_tree_branch(agent, node, direction, visited_cells, 0, 1) + else: # Minor node + direction = first(self.get_possible_transitions(node.position, agent.direction)) + root_tree_node.childs['F'] = self.get_tree_branch(agent, node, direction, visited_cells, 0, 1) + + else: # Just create a single child in the forward direction + prev_node, _, direction, distance = first(self.edge_positions[key]) + root_tree_node.childs['F'] = self.get_tree_branch(agent, prev_node, direction, visited_cells, -distance, 1) + + self.env.dev_obs_dict[handle] = visited_cells + + return root_tree_node + + # Get the next tree node, starting from `node`, facing `orientation`, and moving in `direction`. + def get_tree_branch(self, agent, RailNode node, int direction, visited_cells, int total_distance, int depth): + visited_cells.add((*node.position, 0)) + next_node, distance = node.edges[direction] + + cdef int edge_length = 0 + cdef int max_malfunction_length = 0 + cdef int num_agents_same_direction = 0 + cdef int num_agents_other_direction = 0 + cdef int distance_to_minor_node = positive_infinity + cdef int distance_to_other_agent = positive_infinity + cdef int distance_to_own_target = positive_infinity + cdef int distance_to_other_target = positive_infinity + cdef float min_agent_speed = 1 + cdef int num_agent_departures = 0 + + cdef int orientation = 0 + cdef int dist = 0 + + cdef int tmp_dist = 0 + cdef int tmp1 = 0 # Speed/Malfunction Length + + cdef tuple key = tuple() + cdef tuple next_key = tuple() + + cdef list path = list() + + + # Skip ahead until we get to a major node, logging any agents on the tracks along the way + while True: + path = self.edge_paths.get((node.position, direction), []) + orientation = path[len(path) - 1][len(path[len(path) - 1]) - 1] if path else direction + dist = total_distance + edge_length + key = (*node.position, direction) + next_key = (*next_node.position, orientation) + + visited_cells.update(path) + visited_cells.add((*next_node.position, 0)) + + # Check for other agents on the junctions up ahead + if next_key in self.nodes_with_agents_going: + num_agents_same_direction += 1 + # distance_to_other_agent = min(distance_to_other_agent, edge_length + distance) + min_agent_speed = min(min_agent_speed, self.nodes_with_agents_going[next_key]) + + if next_key in self.nodes_with_agents_coming: + num_agents_other_direction += 1 + distance_to_other_agent = min(distance_to_other_agent, edge_length + distance) + + if next_key in self.nodes_with_departures: + num_agent_departures += 1 + if next_key in self.nodes_with_malfunctions: + max_malfunction_length = max(max_malfunction_length, self.nodes_with_malfunctions[next_key]) + + # Check for other agents along the tracks up ahead + for tmp_dist, tmp1 in self.edges_with_agents_going[key].values(): + if dist + tmp_dist > 0: + num_agents_same_direction += 1 + 
min_agent_speed = min(min_agent_speed, tmp1) + # distance_to_other_agent = min(distance_to_other_agent, edge_length + d) + + for tmp_dist, _ in self.edges_with_agents_coming[key].values(): + if dist + tmp_dist > 0: + num_agents_other_direction += 1 + distance_to_other_agent = min(distance_to_other_agent, edge_length + tmp_dist) + + for tmp_dist in self.edges_with_departures[key].values(): + if dist + tmp_dist > 0: + num_agent_departures += 1 + + for tmp_dist, tmp1 in self.edges_with_malfunctions[key].values(): + if dist + tmp_dist > 0: + max_malfunction_length = max(max_malfunction_length, tmp1) + + # Check for target nodes up ahead + if next_node.is_target: + if self.is_own_target(agent, next_node): + distance_to_own_target = min(distance_to_own_target, edge_length + distance) + else: + distance_to_other_target = min(distance_to_other_target, edge_length + distance) + + # Move on to the next node + node = next_node + edge_length += distance + + if len(node.edges) == 1 and not self.is_own_target(agent, node): # This is a minor node, keep exploring + direction, (next_node, distance) = first(node.edges.items()) + if not node.is_target: + distance_to_minor_node = min(distance_to_minor_node, edge_length) + else: + break + + # Create a new tree node and populate its children + cdef dict children = {} + cdef str x = '' + if depth < self.max_depth: + for x in ACTIONS: + children[x] = negative_infinity + if not self.is_own_target(agent, node): + for direction in node.edges.keys(): + children[get_action(orientation, direction)] = \ + self.get_tree_branch(agent, node, direction, visited_cells, total_distance + edge_length, + depth + 1) + + return Node(dist_own_target_encountered=total_distance + distance_to_own_target, + dist_other_target_encountered=total_distance + distance_to_other_target, + dist_other_agent_encountered=total_distance + distance_to_other_agent, + dist_potential_conflict=positive_infinity, + dist_unusable_switch=total_distance + distance_to_minor_node, + dist_to_next_branch=total_distance + edge_length, + dist_min_to_target=self.env.distance_map.get()[(agent.handle, *node.position, orientation)] or 0, + num_agents_same_direction=num_agents_same_direction, + num_agents_opposite_direction=num_agents_other_direction, + num_agents_malfunctioning=max_malfunction_length, + speed_min_fractional=min_agent_speed, + num_agents_ready_to_depart=num_agent_departures, + childs=children) + + # Helper functions + + def get_possible_transitions(self, tuple position, int direction): + return [i for i, allowed in enumerate(self.env.rail.get_transitions(*position, direction)) if allowed] + + def get_all_transitions(self, tuple position): + return tuple(tuple(i for i, allowed in enumerate(bits) if allowed == '1') + for bits in f'{self.env.rail.get_full_transitions(*position):019_b}'.split("_")) + + def is_junction(self, tuple position): + return any(map(_check_len1, self.get_all_transitions(position))) + + def is_target(self, tuple position): + return position in self.target_positions + + def is_own_target(self, agent, RailNode node): + return agent.target == node.position + + +ZERO_NODE = torch.zeros((1, 11)) + +# Recursively create numpy arrays for each tree node +cpdef create_tree_features(node, int max_depth, list data): + cdef list nodes = [(node, 0)] + cdef int current_depth = 0 + for node, current_depth in nodes: + if node == negative_infinity or node == positive_infinity or node is None: + data.append(ZERO_NODE.expand((4 ** (max_depth - current_depth + 1) - 1) // 3, -1)) + else: + 
data.append(torch.FloatTensor(node[:11]).unsqueeze(0)) + if node.childs: + for direction in ACTIONS: + nodes.append((node.childs[direction], current_depth + 1)) + +# Normalize a tree observation +cpdef normalize_observation(tuple observations, int max_depth, shared_tensor, int starting_index): + cdef list data = [[[] for _ in range(len(observations[0]))] for _ in range(len(observations))] + cdef int i = 0 + cdef int sub = 0 + for i, tree in enumerate(observations, 1): + if tree is None: + break + if isinstance(tree, dict): + tree = tree.values() + for sub, t in enumerate(tree): + if isinstance(t, dict): + for d in t.values(): + create_tree_features(d, max_depth, data[i][sub]) + else: + create_tree_features(t, max_depth, data[i][sub]) + + + shared_tensor[starting_index:starting_index + i] = torch.stack([torch.stack([torch.cat(dat, 0) + for dat in tree if dat != []], -1) + for tree in data if tree != []], 0) diff --git a/src/ppo/__init__.py b/src/ppo/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/ppo/agent.py b/src/ppo/agent.py deleted file mode 100644 index 3f7071c..0000000 --- a/src/ppo/agent.py +++ /dev/null @@ -1,103 +0,0 @@ -import pickle -import random -import numpy as np -import torch -from torch.distributions.categorical import Categorical - -from ppo.model import PolicyNetwork -from replay_memory import Episode, ReplayBuffer - -BUFFER_SIZE = 32_000 -BATCH_SIZE = 4096 -GAMMA = 0.8 -LR = 0.5e-4 -CLIP_FACTOR = .005 -UPDATE_EVERY = 120 - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - -class Agent: - def __init__(self, state_size, action_size, num_agents): - self.policy = PolicyNetwork(state_size, action_size).to(device) - self.old_policy = PolicyNetwork(state_size, action_size).to(device) - self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR) - - self.episodes = [Episode() for _ in range(num_agents)] - self.memory = ReplayBuffer(BUFFER_SIZE) - self.t_step = 0 - - def reset(self): - self.finished = [False] * len(self.episodes) - - - # Decide on an action to take in the environment - - def act(self, state, eps=None): - self.policy.eval() - with torch.no_grad(): - output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device)) - return Categorical(output).sample().item() - - - # Record the results of the agent's action and update the model - - def step(self, handle, state, action, next_state, agent_done, episode_done, collision): - if not self.finished[handle]: - if agent_done: - reward = 1 - elif collision: - reward = -.5 - else: reward = 0 - - # Push experience into Episode memory - self.episodes[handle].push(state, action, reward, next_state, agent_done or episode_done) - - # When we finish the episode, discount rewards and push the experience into replay memory - if agent_done or episode_done: - self.episodes[handle].discount_rewards(GAMMA) - self.memory.push_episode(self.episodes[handle]) - self.episodes[handle].reset() - self.finished[handle] = True - - # Perform a gradient update every UPDATE_EVERY time steps - self.t_step = (self.t_step + 1) % UPDATE_EVERY - if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4: - self.learn(*self.memory.sample(BATCH_SIZE, device)) - - def learn(self, states, actions, rewards, next_state, done): - self.policy.train() - - responsible_outputs = torch.gather(self.policy(states), 1, actions) - old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach() - - # rewards = rewards - rewards.mean() - ratio = responsible_outputs / 
(old_responsible_outputs + 1e-5) - clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) - loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean() - - # Compute loss and perform a gradient step - self.old_policy.load_state_dict(self.policy.state_dict()) - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - - # Checkpointing methods - - def save(self, path, *data): - torch.save(self.policy.state_dict(), path / 'ppo/model_checkpoint.policy') - torch.save(self.optimizer.state_dict(), path / 'ppo/model_checkpoint.optimizer') - with open(path / 'ppo/model_checkpoint.meta', 'wb') as file: - pickle.dump(data, file) - - def load(self, path, *defaults): - try: - print("Loading model from checkpoint...") - self.policy.load_state_dict(torch.load(path / 'ppo/model_checkpoint.policy')) - self.optimizer.load_state_dict(torch.load(path / 'ppo/model_checkpoint.optimizer')) - with open(path / 'ppo/model_checkpoint.meta', 'rb') as file: - return pickle.load(file) - except: - print("No checkpoint file was found") - return defaults diff --git a/src/ppo/model.py b/src/ppo/model.py deleted file mode 100644 index 51febc2..0000000 --- a/src/ppo/model.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F - - -class PolicyNetwork(nn.Module): - def __init__(self, state_size, action_size, hidsize1=128, hidsize2=64, hidsize3=32): - super().__init__() - self.fc1 = nn.Linear(state_size, hidsize1) - self.fc2 = nn.Linear(hidsize1, hidsize2) - # self.fc3 = nn.Linear(hidsize2, hidsize3) - self.output = nn.Linear(hidsize2, action_size) - self.softmax = nn.Softmax(dim=1) - self.bn0 = nn.BatchNorm1d(state_size, affine=False) - - def forward(self, inputs): - x = self.bn0(inputs.float()) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - # x = F.relu(self.fc3(x)) - return self.softmax(self.output(x)) diff --git a/src/rail_env.pyx b/src/rail_env.pyx new file mode 100644 index 0000000..abaeac4 --- /dev/null +++ b/src/rail_env.pyx @@ -0,0 +1,810 @@ +#!python +#cython: boundscheck=False +#cython: initializedcheck=False +#cython: nonecheck=False +#cython: wraparound=False +""" +Definition of the RailEnv environment. +""" +import random +# TODO: _ this is a global method --> utils or remove later +from typing import List, NamedTuple, Dict + +import msgpack_numpy as m +import numpy as np +from flatland.core.env import Environment +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid4 import Grid4Transitions +from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.grid.grid_utils import IntVector2D +# Need to use circular imports for persistence. +from flatland.envs import malfunction_generators as mal_gen +from flatland.envs import persistence +from flatland.envs import rail_generators as rail_gen +from flatland.envs import schedule_generators as sched_gen +from flatland.envs.agent_utils import EnvAgent, RailAgentStatus +from flatland.envs.distance_map import DistanceMap +from flatland.envs.observations import GlobalObsForRailEnv +from gym.utils import seeding + +# Direct import of objects / classes does not work with circular imports. 
+ +# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData +# from flatland.envs.observations import GlobalObsForRailEnv +# from flatland.envs.rail_generators import random_rail_generator, RailGenerator +# from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator + + +m.patch() + +cdef int DO_NOTHING = 0 # implies change of direction in a dead-end! +cdef int MOVE_LEFT = 1 +cdef int MOVE_FORWARD = 2 +cdef int MOVE_RIGHT = 3 +cdef int STOP_MOVING = 4 +cdef int ACTION_COUNT = 5 + +cpdef str to_char(a: int): + return { + 0: 'B', + 1: 'L', + 2: 'F', + 3: 'R', + 4: 'S', + }[a] + +RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)]) + + +class RailEnv(Environment): + """ + RailEnv environment class. + + RailEnv is an environment inspired by a (simplified version of) a rail + network, in which agents (trains) have to navigate to their target + locations in the shortest time possible, while at the same time cooperating + to avoid bottlenecks. + + The valid actions in the environment are: + + - 0: do nothing (continue moving or stay still) + - 1: turn left at switch and move to the next cell; if the agent was not moving, movement is started + - 2: move to the next cell in front of the agent; if the agent was not moving, movement is started + - 3: turn right at switch and move to the next cell; if the agent was not moving, movement is started + - 4: stop moving + + Moving forward in a dead-end cell makes the agent turn 180 degrees and step + to the cell it came from. + + + The actions of the agents are executed in order of their handle to prevent + deadlocks and to allow them to learn relative priorities. + + Reward Function: + + It costs each agent a step_penalty for every time-step taken in the environment, independent of the movement + of the agent. Currently all other penalties, such as the penalties for stopping, starting and invalid actions, are set to 0. + + alpha = 1 + beta = 1 + Reward function parameters: + + - invalid_action_penalty = 0 + - step_penalty = -alpha + - global_reward = beta + - epsilon = avoid rounding errors + - stop_penalty = 0 # penalty for stopping a moving agent + - start_penalty = 0 # penalty for starting a stopped agent + + Stochastic malfunctioning of trains: + Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid + action or cell is selected). + + Every time an agent stops, it has a certain probability of malfunctioning. Malfunctions of trains follow a + Poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes, to keep + complexity manageable. + + TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init(). + For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility. 
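+    For illustration, with the default alpha = beta = 1: an agent moving at full speed 1.0 collects step_penalty = -1 on every elapsed time step, so an agent that needs 40 steps to reach its target accumulates roughly -40 over the episode, while stopping, starting and invalid actions add nothing because their penalties are set to 0.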
+ + """ + alpha = 1.0 + beta = 1.0 + # Epsilon to avoid rounding errors + epsilon = 0.01 + invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty + step_penalty = -1 * alpha + global_reward = 1 * beta + stop_penalty = 0 # penalty for stopping a moving agent + start_penalty = 0 # penalty for starting a stopped agent + + def __init__(self, + int width, + int height, + rail_generator = None, + schedule_generator = None, # : sched_gen.ScheduleGenerator = sched_gen.random_schedule_generator(), + number_of_agents=1, + obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(), + malfunction_generator_and_process_data=None, #mal_gen.no_malfunction_generator(), + remove_agents_at_target=True, + random_seed=1, + record_steps=False + ): + """ + Environment init. + + Parameters + ---------- + rail_generator : function + The rail_generator function is a function that takes the width, + height and agent handles of a rail environment, along with the number of times + the env has been reset, and returns a GridTransitionMap object and a list of + starting positions, targets, and initial orientations for each agent handle. + The rail_generator can pass a distance map in the hints or information for specific schedule_generators. + Implementations can be found in flatland/envs/rail_generators.py + schedule_generator : function + The schedule_generator function is a function that takes the grid, the number of agents and optional hints + and returns a list of starting positions, targets, initial orientations and speeds for all agent handles. + Implementations can be found in flatland/envs/schedule_generators.py + width : int + The width of the rail map. Potentially in the future, + a range of widths to sample from. + height : int + The height of the rail map. Potentially in the future, + a range of heights to sample from. + number_of_agents : int + Number of agents to spawn on the map. Potentially in the future, + a range of number of agents to sample from. + obs_builder_object: ObservationBuilder object + ObservationBuilder-derived object that builds observation + vectors for each agent. + remove_agents_at_target : bool + If remove_agents_at_target is set to true then the agents will be removed by being placed at + RailEnv.DEPOT_POSITION when the agent has reached its target position. 
+ random_seed : int or None + if None, it is ignored; otherwise the random generators are seeded with this number to ensure + that stochastic operations are reproducible across multiple runs + """ + super().__init__() + + if malfunction_generator_and_process_data is None: + malfunction_generator_and_process_data = mal_gen.no_malfunction_generator() + self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data + #self.rail_generator: RailGenerator = rail_generator + if rail_generator is None: + rail_generator = rail_gen.random_rail_generator() + self.rail_generator = rail_generator + #self.schedule_generator: ScheduleGenerator = schedule_generator + if schedule_generator is None: + schedule_generator = sched_gen.random_schedule_generator() + self.schedule_generator = schedule_generator + + self.rail = None + self.width = width + self.height = height + + self.remove_agents_at_target = remove_agents_at_target + + self.rewards = [0] * number_of_agents + self.done = False + self.obs_builder = obs_builder_object + self.obs_builder.set_env(self) + + self._max_episode_steps = 0 + self._elapsed_steps = 0 + + self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False) + + self.obs_dict = {} + self.rewards_dict = {} + self.dev_obs_dict = {} + self.dev_pred_dict = {} + + self.agents = [] + self.number_of_agents = number_of_agents + self.num_resets = 0 + self.distance_map = DistanceMap(self.agents, self.height, self.width) + + self.action_space = [5] + + self.random_seed = random_seed + self.np_random, seed = seeding.np_random(random_seed) + random.seed(seed) + + self.valid_positions = None + + # global numpy array of agent positions; True means that there is an agent at that cell + self.agent_positions = np.full((height, width), False) + + # save episode timesteps, i.e. agent positions, orientations. 
(not yet actions / observations) + self.record_steps = record_steps # whether to save timesteps + # save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps] + self.cur_episode = [] + self.list_actions = [] # save actions in here + + def get_num_agents(self) -> int: + return len(self.agents) + + def reset_agents(self): + """ Reset the agents to their starting positions + """ + for agent in self.agents: + agent.reset() + self.active_agents = [i for i in range(len(self.agents))] + + def action_required(self, agent): + """ + Check if an agent needs to provide an action + + Parameters + ---------- + agent: RailEnvAgent + Agent we want to check + + Returns + ------- + True: Agent needs to provide an action + False: Agent cannot provide an action + """ + + return (agent.status == RailAgentStatus.READY_TO_DEPART or ( + agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0)) + def reset(self) -> (Dict, Dict): + """ + reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed) + + The method resets the rail environment + + Parameters + ---------- + regenerate_rail : bool, optional + regenerate the rails + regenerate_schedule : bool, optional + regenerate the schedule and the static agents + activate_agents : bool, optional + activate the agents + random_seed : bool, optional + random seed for environment + + Returns + ------- + observation_dict: Dict + Dictionary with an observation for each agent + info_dict: Dict with agent specific information + + """ + + cdef dict optionals = {} + rail, optionals = self.rail_generator(self.width, self.height, self.number_of_agents, self.num_resets, + self.np_random) + + self.rail = rail + self.height, self.width = self.rail.grid.shape + + # Do a new set_env call on the obs_builder to ensure + # that obs_builder specific instantiations are made according to the + # specifications of the current environment : like width, height, etc + self.obs_builder.set_env(self) + + if optionals and 'distance_map' in optionals: + self.distance_map.set(optionals['distance_map']) + + agents_hints = None + if optionals and 'agents_hints' in optionals: + agents_hints = optionals['agents_hints'] + + schedule = self.schedule_generator(self.rail, self.number_of_agents, agents_hints, self.num_resets, + self.np_random) + self.agents = EnvAgent.from_schedule(schedule) + + # Get max number of allowed time steps from schedule generator + # Look at the specific schedule generator used to see where this number comes from + self._max_episode_steps = schedule.max_episode_steps + + self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1 + + # Reset agents to initial + self.reset_agents() + + for agent in self.agents: + + self._break_agent(agent) + + if agent.malfunction_data["malfunction"] > 0: + agent.speed_data['transition_action_on_cellexit'] = DO_NOTHING + + # Fix agents that finished their malfunction + self._fix_agent_after_malfunction(agent) + + self.num_resets += 1 + self._elapsed_steps = 0 + + # TODO perhaps dones should be part of each agent. 
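+        # For illustration: `dones` maps every agent handle plus the special key "__all__" to a completion flag, e.g. {0: False, 1: False, "__all__": False} for an environment with two agents.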
+ self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + + # Reset the state of the observation builder with the new environment + self.obs_builder.reset() + self.distance_map.reset(self.agents, self.rail) + + # Reset the malfunction generator + self.malfunction_generator(reset=True) + + # Empty the episode store of agent positions + self.cur_episode = [] + + cdef dict info_dict = { + 'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)}, + 'malfunction': { + i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents) + }, + 'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)}, + 'status': {i: agent.status for i, agent in enumerate(self.agents)} + } + # Return the new observation vectors for each agent + return self._get_observations(), info_dict + + def _fix_agent_after_malfunction(self, agent: EnvAgent): + """ + Updates agent malfunction variables and fixes broken agents + + Parameters + ---------- + agent + """ + + # Ignore agents that are OK + if self._is_agent_ok(agent): + return + + # Reduce number of malfunction steps left + if agent.malfunction_data['malfunction'] > 1: + agent.malfunction_data['malfunction'] -= 1 + return + + # Restart agents at the end of their malfunction + agent.malfunction_data['malfunction'] -= 1 + if 'moving_before_malfunction' in agent.malfunction_data: + agent.moving = agent.malfunction_data['moving_before_malfunction'] + return + + def _break_agent(self, agent: EnvAgent): + """ + Malfunction generator that breaks agents at a given rate. + + Parameters + ---------- + agent + + """ + + malfunction: Malfunction = self.malfunction_generator(agent, self.np_random) + if malfunction.num_broken_steps > 0: + agent.malfunction_data['malfunction'] = malfunction.num_broken_steps + agent.malfunction_data['moving_before_malfunction'] = agent.moving + agent.malfunction_data['nr_malfunctions'] += 1 + + return + + def step(self, dict action_dict_): + """ + Updates rewards for the agents at a step. + + """ + self._elapsed_steps += 1 + + # If we're done, set reward and info_dict and step() is done. 
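+        # step() follows the gym-style contract and returns (observations, rewards, dones, info), each keyed by agent handle; `info` bundles the 'action_required', 'malfunction', 'speed' and 'status' sub-dicts assembled below.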
+ cdef dict info_dict = {} + cdef int i_agent = 0 + + if self.dones["__all__"]: + self.rewards_dict = {} + info_dict = {"action_required": {}, + "malfunction": {}, + "speed": {}, + "status": {}, + } + + for i_agent, agent in enumerate(self.agents): + self.rewards_dict[i_agent] = self.global_reward + info_dict["action_required"][i_agent] = False + info_dict["malfunction"][i_agent] = 0 + info_dict["speed"][i_agent] = 0 + info_dict["status"][i_agent] = agent.status + + return self._get_observations(), self.rewards_dict, self.dones, info_dict + + # Reset the step rewards + self.rewards_dict = dict() + info_dict = {"action_required": {}, + "malfunction": {}, + "speed": {}, + "status": {}, + } + cdef bint have_all_agents_ended = True # boolean flag to check if all agents are done + + for i_agent, agent in enumerate(self.agents): + # Reset the step rewards + self.rewards_dict[i_agent] = 0 + + # Induce malfunction before we do a step, thus a broken agent can't move in this step + self._break_agent(agent) + + # Perform step on the agent + self._step_agent(i_agent, action_dict_.get(i_agent)) + + # manage the boolean flag to check if all agents are indeed done (or done_removed) + have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]) + + # Build info dict + info_dict["action_required"][i_agent] = self.action_required(agent) + info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction'] + info_dict["speed"][i_agent] = agent.speed_data['speed'] + info_dict["status"][i_agent] = agent.status + + # Fix agents that finished their malfunction such that they can perform an action in the next step + self._fix_agent_after_malfunction(agent) + + # Check for end of episode + set global reward to all rewards! + if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps): + self.dones["__all__"] = True + for i_agent in range(self.get_num_agents()): + self.dones[i_agent] = True + if have_all_agents_ended: + self.dones["__all__"] = True + self.rewards_dict = {i: self.global_reward * (1 - self.dones[i]) for i in range(self.get_num_agents())} + else: + for i_agent in range(self.get_num_agents()): + if self.dones[i_agent]: + self.rewards_dict[i_agent] = 0 + if self.record_steps: + self.record_timestep(action_dict_) + + return self._get_observations(), self.rewards_dict, self.dones, info_dict + + + def _step_agent(self, int i_agent, int action): + """ + Performs a step and step, start and stop penalty on a single agent in the following sub steps: + - malfunction + - action handling if at the beginning of cell + - movement + + Parameters + ---------- + i_agent : int + + """ + agent = self.agents[i_agent] + if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]: # this agent has already completed... 
+ return + + # agent gets active by a MOVE_* action and if c + if agent.status == RailAgentStatus.READY_TO_DEPART: + if action in [MOVE_LEFT, MOVE_RIGHT, MOVE_FORWARD] and self.cell_free(agent.initial_position): + agent.status = RailAgentStatus.ACTIVE + self._set_agent_to_initial_position(agent, agent.initial_position) + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + return + else: + # TODO: Here we need to check for the departure time in future releases with full schedules + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + return + + agent.old_direction = agent.direction + agent.old_position = agent.position + + # if agent is broken, actions are ignored and agent does not move. + # full step penalty in this case + if agent.malfunction_data['malfunction'] > 0: + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + return + + # Is the agent at the beginning of the cell? Then, it can take an action. + # As long as the agent is malfunctioning or stopped at the beginning of the cell, + # different actions may be taken! + if agent.speed_data['position_fraction'] == 0: + # No action has been supplied for this agent -> set DO_NOTHING as default + if action is None: + action = DO_NOTHING + + if action < 0 or action > ACTION_COUNT: # + print('ERROR: illegal action=', action, + 'for agent with index=', i_agent, + '"DO NOTHING" will be executed instead') + action = DO_NOTHING + + if action == DO_NOTHING and agent.moving: + # Keep moving + action = MOVE_FORWARD + + if action == STOP_MOVING and agent.moving: + # Only allow halting an agent on entering new cells. + agent.moving = False + self.rewards_dict[i_agent] += self.stop_penalty + + if not agent.moving and not ( + action == DO_NOTHING or + action == STOP_MOVING): + # Allow agent to start with any forward or direction action + agent.moving = True + self.rewards_dict[i_agent] += self.start_penalty + + # Store the action if action is moving + # If not moving, the action will be stored when the agent starts moving again. + if agent.moving: + _action_stored = False + _, new_cell_valid, new_direction, new_position, transition_valid = \ + self._check_action_on_agent(action, agent) + + if all([new_cell_valid, transition_valid]): + agent.speed_data['transition_action_on_cellexit'] = action + _action_stored = True + else: + # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, + # try to keep moving forward! + if action in (MOVE_LEFT, MOVE_RIGHT): + _, new_cell_valid, new_direction, new_position, transition_valid = \ + self._check_action_on_agent(MOVE_FORWARD, agent) + + if all([new_cell_valid, transition_valid]): + agent.speed_data['transition_action_on_cellexit'] = MOVE_FORWARD + _action_stored = True + + if not _action_stored: + # If the agent cannot move due to an invalid transition, we set its state to not moving + self.rewards_dict[i_agent] += self.invalid_action_penalty + self.rewards_dict[i_agent] += self.stop_penalty + agent.moving = False + + # Now perform a movement. + # If agent.moving, increment the position_fraction by the speed of the agent + # If the new position fraction is >= 1, reset to 0, and perform the stored + # transition_action_on_cellexit if the cell is free. 
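+        # For illustration: an agent with speed 0.25 needs four step() calls to cross a cell; the stored transition_action_on_cellexit only takes effect on the call where position_fraction reaches 1.0 and the destination cell is free.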
+ if agent.moving: + agent.speed_data['position_fraction'] += agent.speed_data['speed'] + if agent.speed_data['position_fraction'] > 1.0-1e-3: + # Perform stored action to transition to the next cell as soon as cell is free + # Notice that we've already checked new_cell_valid and transition valid when we stored the action, + # so we only have to check cell_free now! + + # cell and transition validity was checked when we stored transition_action_on_cellexit! + cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( + agent.speed_data['transition_action_on_cellexit'], agent) + + # N.B. validity of new_cell and transition should have been verified before the action was stored! + assert new_cell_valid + assert transition_valid + if cell_free: + self._move_agent_to_new_position(agent, new_position) + agent.direction = new_direction + agent.speed_data['position_fraction'] = 0.0 + + # has the agent reached its target? + if np.equal(agent.position, agent.target).all(): + agent.status = RailAgentStatus.DONE + self.dones[i_agent] = True + self.active_agents.remove(i_agent) + agent.moving = False + self._remove_agent_from_scene(agent) + else: + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + else: + # step penalty if not moving (stopped now or before) + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + + def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D): + """ + Sets the agent to its initial position. Updates the agent object and the position + of the agent inside the global agent_position numpy array + + Parameters + ------- + agent: EnvAgent object + new_position: IntVector2D + """ + agent.position = new_position + self.agent_positions[agent.position] = agent.handle + + def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D): + """ + Move the agent to the a new position. Updates the agent object and the position + of the agent inside the global agent_position numpy array + + Parameters + ------- + agent: EnvAgent object + new_position: IntVector2D + """ + agent.position = new_position + self.agent_positions[agent.old_position] = -1 + self.agent_positions[agent.position] = agent.handle + + def _remove_agent_from_scene(self, agent: EnvAgent): + """ + Remove the agent from the scene. Updates the agent object and the position + of the agent inside the global agent_position numpy array + + Parameters + ------- + agent: EnvAgent object + """ + self.agent_positions[agent.position] = -1 + if self.remove_agents_at_target: + agent.position = None + agent.status = RailAgentStatus.DONE_REMOVED + + def _check_action_on_agent(self, int action, agent: EnvAgent): + """ + + Parameters + ---------- + action : + agent : EnvAgent + + Returns + ------- + bool + Is it a legal move? 
+ 1) transition allows the new_direction in the cell, + 2) the new cell is not empty (case 0), + 3) the cell is free, i.e., no agent is currently in that cell + + + """ + # compute number of possible transitions in the current + # cell used to check for invalid actions + cdef tuple act_chk = self.check_action(agent, action) + cdef int new_direction = act_chk[0] + transition_valid = act_chk[1] + cdef bint cell_free = False + cdef tuple new_position = get_new_position(agent.position, new_direction) + + cdef bint new_cell_valid = ( + np.array_equal( # Check the new position is still in the grid + new_position, + np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) + and # check the new position has some transitions (ie is not an empty cell) + self.rail.get_full_transitions(*new_position) > 0) + if transition_valid is None: + transition_valid = self.rail.get_transition( + (*agent.position, agent.direction), + new_direction) + # only call cell_free() if new cell is inside the scene + if new_cell_valid: + # Check the new position is not the same as any of the existing agent positions + # (including itself, for simplicity, since it is moving) + cell_free = self.cell_free(new_position) + + return cell_free, new_cell_valid, new_direction, new_position, transition_valid + + def record_timestep(self, dActions): + ''' Record the positions and orientations of all agents in memory, in the cur_episode + ''' + cdef list list_agents_state = [] + cdef int i_agent = 0 + cdef tuple pos = tuple() + for i_agent in range(self.get_num_agents()): + agent = self.agents[i_agent] + # the int cast is to avoid numpy types which may cause problems with msgpack + # in env v2, agents may have position None, before starting + if agent.position is None: + pos = (0, 0) + else: + pos = (int(agent.position[0]), int(agent.position[1])) + # print("pos:", pos, type(pos[0])) + list_agents_state.append( + [*pos, int(agent.direction), agent.malfunction_data["malfunction"]]) + + self.cur_episode.append(list_agents_state) + self.list_actions.append(dActions) + + def cell_free(self, position: IntVector2D) -> bool: + """ + Utility to check if a cell is free + + Parameters: + -------- + position : Tuple[int, int] + + Returns + ------- + bool + is the cell free or not? 
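+            (True if no agent currently occupies `position`; the check reads the global `agent_positions` array, in which -1 marks a free cell.)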
+ + """ + return self.agent_positions[position] == -1 + + def check_action(self, agent: EnvAgent, action): + """ + + Parameters + ---------- + agent : EnvAgent + action : + + Returns + ------- + Tuple[Grid4TransitionsEnum,Tuple[int,int]] + + + + """ + transition_valid = None + cdef tuple possible_transitions = self.rail.get_transitions(*agent.position, agent.direction) + cdef int num_transitions = np.count_nonzero(possible_transitions) + cdef int new_direction = agent.direction + + if action == MOVE_LEFT: + new_direction = agent.direction - 1 + if num_transitions <= 1: + transition_valid = False + + elif action == MOVE_RIGHT: + new_direction = agent.direction + 1 + if num_transitions <= 1: + transition_valid = False + + new_direction %= 4 + + if action == MOVE_FORWARD and num_transitions == 1: + # - dead-end, straight line or curved line; + # new_direction will be the only valid transition + # - take only available transition + new_direction = np.argmax(possible_transitions) + transition_valid = True + return new_direction, transition_valid + + def _get_observations(self): + """ + Utility which returns the observations for an agent with respect to environment + + Returns + ------ + Dict object + """ + #print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}") + self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) + return self.obs_dict + + def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]: + """ + Returns directions in which the agent can move + + Parameters: + --------- + row : int + col : int + + Returns: + ------- + List[int] + """ + return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col)) + + def _exp_distirbution_synced(self, rate: float) -> float: + """ + Generates sample from exponential distribution + We need this to guarantee synchronity between different instances with same seed. 
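+        The sample is drawn via inverse transform sampling, x = -log(1 - u) * rate, with u taken from this environment's own np_random generator so that runs sharing a seed stay in sync.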
+ :param rate: + :return: + """ + cdef float u = self.np_random.rand() + cdef float x = - np.log(1 - u) * rate + return x + + def _is_agent_ok(self, agent: EnvAgent) -> bool: + """ + Check if an agent is ok, meaning it can move and is not malfuncitoinig + Parameters + ---------- + agent + + Returns + ------- + True if agent is ok, False otherwise + + """ + return agent.malfunction_data['malfunction'] < 1 + + def save(self, filename): + print("deprecated call to env.save() - pls call RailEnvPersister.save()") + persistence.RailEnvPersister.save(self, filename) diff --git a/src/railway_utils.py b/src/railway_utils.py index 6362b9d..1cf0216 100644 --- a/src/railway_utils.py +++ b/src/railway_utils.py @@ -1,30 +1,108 @@ +import os import pickle +import numpy as np + from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator +try: + from .agent import BATCH_SIZE +except: + from agent import BATCH_SIZE + + +class Generator: + + def __init__(self, path, start_index=0): + self.path = path + self.index = start_index + self.data = iter([]) + self.len = 0 + + self._load() + + def _load(self): + with open(self.path, 'rb') as file: + data = pickle.load(file) + self.len = len(data) + self.data = iter(data[self.index:]) + + def __next__(self): + try: + data = next(self.data) + except StopIteration: + self._load() + if self.index >= self.len: + print("[WARNING] Restarting training loop from zero") + self.index = 0 + self._load() + data = next(self) + self.index += 1 + return data + + def __len__(self): + return self.len + + def __call__(self, *args, **kwargs): + return next(self) + + +class RailGenerator: + def __init__(self, width=35, base=1.5, factor=2): + self.rail_generator = sparse_rail_generator(grid_mode=False, + max_num_cities=max(2, width ** 2 // 300), + max_rails_between_cities=2, + max_rails_in_city=3) + self.sub_idx = 0 + self.top_idx = 0 + self.width = width + self.base = base + self.factor = factor + + def __next__(self): + self.sub_idx += 1 + if self.sub_idx == BATCH_SIZE: + self.sub_idx = 0 + self.top_idx += 1 + return self.rail_generator(self.width, self.width, int(self.factor * self.base ** self.top_idx), np_random=np.random) + + def __call__(self, *args, **kwargs): + return next(self) + + +class ScheduleGenerator: + def __init__(self, base=1.5, factor=2): + self.schedule_generator = sparse_schedule_generator({1.: 1.}) + self.sub_idx = 0 + self.top_idx = 0 + self.base = base + self.factor = factor + + def __next__(self, rail, hints): + if self.sub_idx == BATCH_SIZE: + self.sub_idx = 0 + self.top_idx += 1 + self.sub_idx += 1 + return self.schedule_generator(rail, int(self.factor * self.base ** self.top_idx), hints, np_random=np.random) + + def __call__(self, rail, _, hints, *args, **kwargs): + return self.__next__(rail, hints) + # Helper function to load in precomputed railway networks -def load_precomputed_railways(project_root, flags): - with open(project_root / f'railroads/rail_networks_{flags.num_agents}x{flags.grid_width}x{flags.grid_height}.pkl', 'rb') as file: - data = pickle.load(file) - rail_networks = iter(data) - print(f"Loading {len(data)} railways...") - with open(project_root / f'railroads/schedules_{flags.num_agents}x{flags.grid_width}x{flags.grid_height}.pkl', 'rb') as file: - schedules = iter(pickle.load(file)) - - rail_generator = lambda *args: next(rail_networks) - schedule_generator = lambda *args: next(schedules) - return rail_generator, schedule_generator +def 
load_precomputed_railways(project_root, start_index, big=False): + prefix = os.path.join(project_root, 'railroads') + if big: + suffix = f'_35_6.pkl' # base=1.1 + else: + suffix = f'_35_4.pkl' # base=1.04 + rail = Generator(os.path.join(prefix, 'rail_networks' + suffix), start_index) + sched = Generator(os.path.join(prefix, 'schedules' + suffix), start_index) + print(f"Working on {len(rail)} tracks") + return rail, sched + # Helper function to generate railways on the fly -def create_random_railways(project_root): - speed_ration_map = { - 1 / 1: 1.0, # Fast passenger train - 1 / 2.: 0.0, # Fast freight train - 1 / 3.: 0.0, # Slow commuter train - 1 / 4.: 0.0 } # Slow freight train - - rail_generator = sparse_rail_generator(grid_mode=False, max_num_cities=3, max_rails_between_cities=2, max_rails_in_city=3) - schedule_generator = sparse_schedule_generator(speed_ration_map) - return rail_generator, schedule_generator +def create_random_railways(width, base=1.1, factor=2): + return RailGenerator(width=width, base=base, factor=factor), ScheduleGenerator(base=base, factor=factor) diff --git a/src/replay_memory.py b/src/replay_memory.py deleted file mode 100644 index 61a1b81..0000000 --- a/src/replay_memory.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch -import random -import numpy as np -from collections import namedtuple, deque, Iterable - - -Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done")) - - -class Episode: - memory = [] - - def reset(self): - self.memory = [] - - def push(self, *args): - self.memory.append(tuple(args)) - - def discount_rewards(self, gamma): - running_add = 0. - for i, (state, action, reward, *rest) in list(enumerate(self.memory))[::-1]: - running_add = running_add * gamma + reward - self.memory[i] = (state, action, running_add, *rest) - - -class ReplayBuffer: - def __init__(self, buffer_size): - self.memory = deque(maxlen=buffer_size) - - def push(self, state, action, reward, next_state, done): - self.memory.append(Transition(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)) - - def push_episode(self, episode): - for step in episode.memory: - self.push(*step) - - def sample(self, batch_size, device): - experiences = random.sample(self.memory, k=batch_size) - - states = torch.from_numpy(self.stack([e.state for e in experiences])).float().to(device) - actions = torch.from_numpy(self.stack([e.action for e in experiences])).long().to(device) - rewards = torch.from_numpy(self.stack([e.reward for e in experiences])).float().to(device) - next_states = torch.from_numpy(self.stack([e.next_state for e in experiences])).float().to(device) - dones = torch.from_numpy(self.stack([e.done for e in experiences]).astype(np.uint8)).float().to(device) - - return states, actions, rewards, next_states, dones - - def stack(self, states): - sub_dims = states[0].shape[1:] if isinstance(states[0], Iterable) else [1] - return np.reshape(np.array(states), (len(states), *sub_dims)) - - def __len__(self): - return len(self.memory) diff --git a/src/train.py b/src/train.py index d52112b..f3218ff 100644 --- a/src/train.py +++ b/src/train.py @@ -1,212 +1,293 @@ -import cv2 -import time import argparse -import numpy as np +import copy +import time +from itertools import zip_longest from pathlib import Path -from collections import deque -from tensorboardX import SummaryWriter -from flatland.envs.rail_env import RailEnv, RailEnvActions +import numpy as np +import torch from flatland.envs.malfunction_generators import 
malfunction_from_params, MalfunctionParameters -from flatland.utils.rendertools import RenderTool, AgentRenderVariant +# from flatland.utils.rendertools import RenderTool, AgentRenderVariant +from pathos import multiprocessing -from dqn.agent import Agent as DQN_Agent -from ppo.agent import Agent as PPO_Agent -from tree_observation import TreeObservation -from observation_utils import normalize_observation, is_collision -from railway_utils import load_precomputed_railways, create_random_railways +# import cv2 +torch.jit.optimized_execution(True) + +positive_infinity = int(1e5) +negative_infinity = -positive_infinity + +try: + from .rail_env import RailEnv + from .agent import Agent as DQN_Agent, device, BATCH_SIZE, UPDATE_EVERY + from .normalize_output_data import wrap + from .observation_utils import normalize_observation, TreeObservation, GlobalObsForRailEnv, LocalObsForRailEnv, \ + GlobalStateObs + from .railway_utils import load_precomputed_railways, create_random_railways +except: + from rail_env import RailEnv + from agent import Agent as DQN_Agent, device, BATCH_SIZE, UPDATE_EVERY + from normalize_output_data import wrap + from observation_utils import normalize_observation, TreeObservation, GlobalObsForRailEnv, LocalObsForRailEnv, \ + GlobalStateObs + from railway_utils import load_precomputed_railways, create_random_railways + project_root = Path(__file__).resolve().parent.parent parser = argparse.ArgumentParser(description="Train an agent in the flatland environment") boolean = lambda x: str(x).lower() == 'true' # Task parameters parser.add_argument("--train", type=boolean, default=True, help="Whether to train the model or just evaluate it") -parser.add_argument("--load-model", default=False, action='store_true', help="Whether to load the model from the last checkpoint") -parser.add_argument("--load-railways", type=boolean, default=True, help="Whether to load in pre-generated railway networks") -parser.add_argument("--report-interval", type=int, default=100, help="Iterations between reports") +parser.add_argument("--load-model", default=False, action='store_true', + help="Whether to load the model from the last checkpoint") parser.add_argument("--render-interval", type=int, default=0, help="Iterations between renders") # Environment parameters -parser.add_argument("--grid-width", type=int, default=50, help="Number of columns in the environment grid") -parser.add_argument("--grid-height", type=int, default=50, help="Number of rows in the environment grid") -parser.add_argument("--num-agents", type=int, default=5, help="Number of agents in each episode") -parser.add_argument("--tree-depth", type=int, default=1, help="Depth of the observation tree") +parser.add_argument("--tree-depth", type=int, default=2, help="Depth of the observation tree") +parser.add_argument("--model-depth", type=int, default=5, help="Depth of the observation tree") +parser.add_argument("--hidden-factor", type=int, default=48, help="Depth of the observation tree") +parser.add_argument("--kernel-size", type=int, default=1, help="Depth of the observation tree") +parser.add_argument("--squeeze-heads", type=int, default=4, help="Depth of the observation tree") +parser.add_argument("--observation-size", type=int, default=4, help="Depth of the observation tree") +parser.add_argument("--decoder-depth", type=int, default=1, help="Depth of the observation tree") # Training parameters -parser.add_argument("--agent-type", default="ppo", choices=["dqn", "ppo"], help="Which type of RL agent to use") 
-parser.add_argument("--num-episodes", type=int, default=10000, help="Number of episodes to train for") -parser.add_argument("--epsilon-decay", type=float, default=0.997, help="Decay factor for epsilon-greedy exploration") +parser.add_argument("--step-reward", type=float, default=-1e-2, help="Depth of the observation tree") +parser.add_argument("--collision-reward", type=float, default=-2, help="Depth of the observation tree") +parser.add_argument("--global-environment", type=boolean, default=False, help="Depth of the observation tree") +parser.add_argument("--local-environment", type=boolean, default=False, help="Depth of the observation tree") +parser.add_argument("--state-environment", type=boolean, default=True, help="Depth of the observation tree") +parser.add_argument("--threads", type=int, default=1, help="Depth of the observation tree") flags = parser.parse_args() +if sum((flags.global_environment, flags.local_environment, flags.state_environment)) > 1: + print("Too many environment flags used. Priority is global > local > state.") + +if flags.global_environment: + model_type = 1 + env = GlobalObsForRailEnv() +elif flags.local_environment: + model_type = 1 + env = LocalObsForRailEnv(flags.observation_size) +elif flags.state_environment: + model_type = 2 + env = GlobalStateObs() +else: + model_type = 0 + env = TreeObservation(flags.tree_depth) + +if model_type not in (0, 1, 2): + raise UserWarning("Unknown model type") # Seeded RNG so we can replicate our results -np.random.seed(1) # Create a tensorboard SummaryWriter -summary = SummaryWriter(f'tensorboard/dqn/agents: {flags.num_agents}, tree_depth: {flags.tree_depth}') - -# We need to either load in some pre-generated railways from disk, or else create a random railway generator. -if flags.load_railways: - rail_generator, schedule_generator = load_precomputed_railways(project_root, flags) -else: rail_generator, schedule_generator = create_random_railways(project_root) - -# Create the Flatland environment -env = RailEnv(width=flags.grid_width, height=flags.grid_height, number_of_agents=flags.num_agents, - rail_generator=rail_generator, - schedule_generator=schedule_generator, - malfunction_generator_and_process_data=malfunction_from_params(MalfunctionParameters(1 / 8000, 15, 50)), - obs_builder_object=TreeObservation(max_depth=flags.tree_depth) -) - -# After training we want to render the results so we also load a renderer -env_renderer = RenderTool(env, gl="PILSVG", screen_width=800, screen_height=800, agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX) - # Calculate the state size based on the number of nodes in the tree observation -num_features_per_node = env.obs_builder.observation_dim -num_nodes = sum(np.power(4, i) for i in range(flags.tree_depth + 1)) +num_features_per_node = 11 # env.obs_builder.observation_dim +num_nodes = int('1' * (flags.tree_depth + 1), 4) state_size = num_nodes * num_features_per_node action_size = 5 - -# Add some variables to keep track of the progress -scores_window, steps_window, collisions_window, done_window = [deque(maxlen=200) for _ in range(4)] -agent_obs = [None] * flags.num_agents -agent_obs_buffer = [None] * flags.num_agents -agent_action_buffer = [2] * flags.num_agents -max_steps = 8 * (flags.grid_width + flags.grid_height) -start_time = time.time() - # Load an RL agent and initialize it from checkpoint if necessary -if flags.agent_type == "dqn": - agent = DQN_Agent(state_size, action_size, flags.num_agents) -elif flags.agent_type == "ppo": - agent = PPO_Agent(state_size, 
action_size, flags.num_agents)
-
+agent = DQN_Agent(state_size,
+                  action_size,
+                  flags.model_depth,
+                  flags.hidden_factor,
+                  flags.kernel_size,
+                  flags.squeeze_heads,
+                  flags.decoder_depth,
+                  model_type)
 if flags.load_model:
-    start, eps = agent.load(project_root / 'checkpoints', 0, 1.0)
-else: start, eps = 0, 1.0
+    start, = agent.load(project_root / 'checkpoints', 0)
+else:
+    start = 0
+# We need to either load in some pre-generated railways from disk, or else create a random railway generator.
+rail_generator, schedule_generator = load_precomputed_railways(project_root, start * BATCH_SIZE)
-if not flags.train:
-    eps = 0.0
+# Create a batch of Flatland environments
+environments = [RailEnv(width=50, height=50, number_of_agents=1,
+                        rail_generator=rail_generator,
+                        schedule_generator=schedule_generator,
+                        malfunction_generator_and_process_data=malfunction_from_params(
+                            MalfunctionParameters(1 / 500, 20, 50)),
+                        obs_builder_object=copy.deepcopy(env),
+                        random_seed=i)
+                for i in range(BATCH_SIZE)]
+env = environments[0]
+# Rendering is disabled for now; see the commented-out render() helper inside the training loop.
-# We don't want to retrain on old railway networks when we restart from a checkpoint, so we just loop
-# through the generators to get all the old networks out of the way
-if start > 0: print(f"Skipping {start} railways")
-for _ in range(0, start):
-    rail_generator()
-    schedule_generator()
+# Add some variables to keep track of the progress
+agent_action_buffer = []
+start_time = time.time()
 # Helper function to detect collisions
-ACTIONS = { 0: 'B', 1: 'L', 2: 'F', 3: 'R', 4: 'S' }
-
-def is_collision(a):
-    if obs[a] is None: return False
-    is_junction = not isinstance(obs[a].childs['L'], float) or not isinstance(obs[a].childs['R'], float)
-
-    if not is_junction or env.agents[a].speed_data['position_fraction'] > 0:
-        action = ACTIONS[env.agents[a].speed_data['transition_action_on_cellexit']] if is_junction else 'F'
-        return obs[a].childs[action].num_agents_opposite_direction > 0 \
-           and obs[a].childs[action].dist_other_agent_encountered <= 1 \
-           and obs[a].childs[action].dist_other_agent_encountered < obs[a].childs[action].dist_unusable_switch
-    else: return False
-
-# Helper function to render the environment
-def render():
-    env_renderer.render_env(show_observations=False)
-    cv2.imshow('Render', cv2.cvtColor(env_renderer.get_image(), cv2.COLOR_BGR2RGB))
-    cv2.waitKey(120)
-
-# Helper function to generate a report
-def get_report(show_time=False):
-    training = 'Training' if flags.train else 'Evaluating'
-    return ' | '.join(filter(None, [
-        f'\r{training} {flags.num_agents} Agents on {flags.grid_width} x {flags.grid_height} Map',
-        f'Episode {episode:<5}',
-        f'Average Score: {np.mean(scores_window):.3f}',
-        f'Average Steps Taken: {np.mean(steps_window):<6.1f}',
-        f'Collisions: {100 * np.mean(collisions_window):>5.2f}%',
-        f'Finished: {100 * np.mean(done_window):>6.2f}%',
-        f'Epsilon: {eps:.2f}' if flags.agent_type == "dqn" else None,
-        f'Time taken: {time.time() - start_time:.2f}s' if show_time else None])) + ' '
-
-
+ACTIONS = {0: 'B', 1: 'L', 2: 'F', 3: 'R', 4: 'S'}
+
+if model_type in (1, 2):
+    def is_collision(a, i):
+        own_agent = environments[i].agents[a]
+        return any(own_agent.position == agent.position
+                   for agent_id, agent in enumerate(environments[i].agents)
+                   if agent_id != a)
+else:  # model_type == 0
+    def is_collision(a, i):
+        if obs[i][a] is None: return False
+        is_junction = not isinstance(obs[i][a].childs['L'], float) or not isinstance(obs[i][a].childs['R'], float)
+
+        if not is_junction or environments[i].agents[a].speed_data['position_fraction'] > 0:
+            action = ACTIONS[
+                environments[i].agents[a].speed_data['transition_action_on_cellexit']] if is_junction else 'F'
+            return obs[i][a].childs[action] != negative_infinity and obs[i][a].childs[action] != positive_infinity \
+                   and obs[i][a].childs[action].num_agents_opposite_direction > 0 \
+                   and obs[i][a].childs[action].dist_other_agent_encountered <= 1 \
+                   and obs[i][a].childs[action].dist_other_agent_encountered < obs[i][a].childs[
+                       action].dist_unusable_switch
+        else:
+            return False
+
+chunk_size = (BATCH_SIZE + 1) // flags.threads
+
+
+def chunk(obj, size):
+    return zip_longest(*[iter(obj)] * size, fillvalue=None)
+
+
+if flags.threads > 1:
+    def normalize(observation, target_tensor):
+        POOL.starmap(func=normalize_observation,
+                     iterable=((o, flags.tree_depth, target_tensor, i * chunk_size)
+                               for i, o in enumerate(chunk(observation, chunk_size))))
+        wrap(target_tensor)
+else:
+    def normalize(observation, target_tensor):
+        normalize_observation(observation, flags.tree_depth, target_tensor, 0)
+        wrap(target_tensor)
+
+
+def as_tensor(array_list):
+    return torch.as_tensor(np.stack(array_list, 0), dtype=torch.float, device=device)
+
+
+def make_tensor(current, old):
+    if model_type == 1:
+        tensor = (torch.cat([old[0], current[0]], -1), torch.zeros(1))
+    elif model_type == 0:
+        tensor = (torch.cat([old[0].flatten(1, 2), current[0].flatten(1, 2)], 1), torch.zeros(1))
+    else:  # model_type == 2
+        tensor = (torch.cat((old[0], current[0]), 1), torch.cat((old[1], current[1]), -1))
+        tensor[1].transpose_(1, -1)
+    tensor = tuple(t.to(device) for t in tensor)
+    return tensor
+
+
+def clone(tensor_tuple):
+    return tuple(t.clone().detach().requires_grad_(t.requires_grad) for t in tensor_tuple)
+
+
+def get_observation_tensor(observation, prev_tensor=None):
+    buffer = None if prev_tensor is None else clone(prev_tensor)
+    if model_type == 1:
+        obs_tensor = as_tensor(observation)
+    elif model_type == 0:
+        obs_tensor = torch.zeros((BATCH_SIZE, state_size // 11, 11, agent_count))
+        normalize(observation, obs_tensor)
+    else:  # model_type == 2
+        rail, obs = zip(*observation)
+        obs_tensor = as_tensor(rail), as_tensor(obs)
+    return obs_tensor, buffer
+
+
+episode = 0
+POOL = multiprocessing.Pool()
+# env_renderer = None
+# def render():
+#     env_renderer.render_env(show_observations=False)
+#     cv2.imshow('Render', cv2.cvtColor(env_renderer.get_image(), cv2.COLOR_BGR2RGB))
+#     cv2.waitKey(120)
 # Main training loop
-for episode in range(start + 1, flags.num_episodes + 1):
+episode = start
+running_stats = {'score': 0, 'steps': 0, 'collisions': 0, 'done': 0, 'finished': 0}
+batch_start = time.time()
+while True:
+    episode += 1
     agent.reset()
-    env_renderer.reset()
-    obs, info = env.reset(True, True)
-    score, steps_taken, collision = 0, 0, False
-
-    # Build initial observations for each agent
-    for a in range(flags.num_agents):
-        agent_obs[a] = normalize_observation(obs[a], flags.tree_depth, zero_center=flags.agent_type == 'dqn')
-        agent_obs_buffer[a] = agent_obs[a].copy()
-
+    obs, info = zip(*[env.reset() for env in environments])
+    # env_renderer = RenderTool(environments[0], gl="PILSVG", screen_width=1000, screen_height=1000, agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX)
+    # env_renderer.reset()
+    score, collision = 0, False
+    agent_count = len(environments[0].agents)
+
+    agent_action_buffer = torch.zeros((BATCH_SIZE, 5, agent_count), device=device, dtype=torch.float)
+    agent_action_buffer[:, 2] += 1
+    agent_obs, _ = get_observation_tensor(obs)
+    agent_obs_buffer = clone(agent_obs)
+    input_tensor = make_tensor(agent_obs, agent_obs_buffer)
     # Run an episode
+    city_count = (env.width * env.height) // 300
+    max_steps = int(8 * (env.width + env.height + agent_count / city_count)) - 10
+    # -10 = keep a small safety margin below the environment's real step limit
+    done = [[False]]
+    _done = [[False]]
+    step = 0
     for step in range(max_steps):
-        update_values = [False] * flags.num_agents
-        action_dict = {}
+        done = _done
-        for a in range(flags.num_agents):
-            if info['action_required'][a]:
-                action_dict[a] = agent.act(agent_obs[a], eps=eps)
-                # action_dict[a] = np.random.randint(5)
-                update_values[a] = True
-                steps_taken += 1
-            else: action_dict[a] = 0
+        mdl_action, ret_action = agent.multi_act(input_tensor, False)
+        action_dict = [dict(enumerate(act_list)) for act_list in ret_action]
         # Environment step
-        obs, rewards, done, info = env.step(action_dict)
-        score += sum(rewards.values()) / flags.num_agents
+        obs, rewards, _done, info = tuple(zip(*[e.step(a) for e, a in zip(environments, action_dict)]))
+        score += sum(i for r in rewards for i in r.values()) / (agent_count * BATCH_SIZE)
         # Check for collisions and episode completion
-        if step == max_steps - 1:
-            done['__all__'] = True
-        if any(is_collision(a) for a in obs):
-            collision = True
-            # done['__all__'] = True
-
+        all_done = (step == (max_steps - 1)) or all(d['__all__'] for d in _done)
+        collision = [[is_collision(a, i) for a in range(agent_count)] for i in range(BATCH_SIZE)]
        # Update replay buffer and train agent
-        for a in range(flags.num_agents):
-            if flags.train and (update_values[a] or done[a] or done['__all__']):
-                agent.step(a, agent_obs_buffer[a], agent_action_buffer[a], agent_obs[a], done[a], done['__all__'], is_collision(a))
-                agent_obs_buffer[a] = agent_obs[a].copy()
-                agent_action_buffer[a] = action_dict[a]
-
-            if obs[a]:
-                agent_obs[a] = normalize_observation(obs[a], flags.tree_depth, zero_center=flags.agent_type == 'dqn')
+        agent_obs, agent_obs_buffer = get_observation_tensor(obs, agent_obs)
+        next_input = make_tensor(agent_obs, agent_obs_buffer)
+
+        if flags.train:
+            agent.step(input_tensor,
+                       agent_action_buffer,
+                       _done,
+                       collision,
+                       next_input,
+                       flags.step_reward,
+                       flags.collision_reward)
+        agent_action_buffer = mdl_action
+
+        if all_done:
+            break
+        input_tensor = next_input
         # Render
         # if flags.render_interval and episode % flags.render_interval == 0:
         #     if collision and all(agent.position for agent in env.agents):
-        #         render()
+        #         if step % 2 == 1:
+        #             print([a.position for a in environments[0].agents])
+        #         render()
         #         print("Collisions detected by agent(s)", ', '.join(str(a) for a in obs if is_collision(a)))
         #         break
-        if done['__all__']: break
-
-    # Epsilon decay
-    if flags.train: eps = max(0.01, flags.epsilon_decay * eps)
-
-    # Save some training statistics in their respective deques
-    tasks_finished = sum(done[i] for i in range(flags.num_agents))
-    done_window.append(tasks_finished / max(1, flags.num_agents))
-    collisions_window.append(1. if collision else 0.)
-    scores_window.append(score / max_steps)
-    steps_window.append(steps_taken)
-
-    # Generate training reports, saving our progress every so often
-    print(get_report(), end=" ")
-    if episode % flags.report_interval == 0:
-        print(get_report(show_time=True))
-        start_time = time.time()
-        if flags.train: agent.save(project_root / 'checkpoints', episode, eps)
-
-        # Add stats to the tensorboard summary
-        summary.add_scalar('performance/avg_score', np.mean(scores_window), episode)
-        summary.add_scalar('performance/avg_steps', np.mean(steps_window), episode)
-        summary.add_scalar('performance/completions', np.mean(done_window), episode)
-        summary.add_scalar('performance/collisions', np.mean(collisions_window), episode)
+    running_stats['score'] += score / max_steps
+    running_stats['steps'] += step
+    running_stats['collisions'] += sum(i for c in collision for i in c) / agent_count
+    running_stats['done'] += sum(d[i] for d in done for i in range(agent_count)) / agent_count
+    running_stats['finished'] += sum(d["__all__"] for d in done)
+
+    if episode % UPDATE_EVERY == 0:
+        running_stats = {k: v / UPDATE_EVERY for k, v in running_stats.items()}
+        print(f'\rBatch{episode:>3} - Episode{BATCH_SIZE * episode:>5} - Agents:{agent_count:>3}'
+              f' | Score: {running_stats["score"]:.4f}'
+              f' | Steps: {running_stats["steps"]:4.0f}'
+              f' | Collisions: {100 * running_stats["collisions"] / BATCH_SIZE:6.2f}%'
+              f' | Done: {100 * running_stats["done"] / BATCH_SIZE:6.2f}%'
+              f' | Finished: {100 * running_stats["finished"] / BATCH_SIZE:6.2f}%'
+              f' | Took: {time.time() - batch_start:5.0f}s')
+        running_stats = {k: 0 for k in running_stats.keys()}
+        batch_start = time.time()
+
+# if flags.train:
+#     agent.save(project_root / 'checkpoints', episode)
diff --git a/src/tree_observation.py b/src/tree_observation.py
deleted file mode 100644
index 91c6e9e..0000000
--- a/src/tree_observation.py
+++ /dev/null
@@ -1,326 +0,0 @@
-import numpy as np
-from collections import defaultdict
-
-from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.observations import TreeObsForRailEnv
-
-
-ACTIONS = ['L', 'F', 'R', 'B']
-Node = TreeObsForRailEnv.Node
-
-def first(list):
-    return next(iter(list))
-
-def get_action(orientation, direction):
-    return ACTIONS[(direction - orientation + 1 + 4) % 4]
-
-def get_direction(orientation, action):
-    if action == 1:
-        return (orientation + 4 - 1) % 4
-    elif action == 3:
-        return (orientation + 1) % 4
-    else: return orientation
-
-
-class RailNode:
-    def __init__(self, position, edge_directions, is_target):
-        self.edges = {}
-        self.position = position
-        self.edge_directions = edge_directions
-        self.is_target = is_target
-
-    def __repr__(self):
-        return f'RailNode({self.position}, {len(self.edges)})'
-
-
-
-class TreeObservation(ObservationBuilder):
-    def __init__(self, max_depth):
-        super().__init__()
-        self.max_depth = max_depth
-        self.observation_dim = 11
-
-
-    # Create a graph representation of the current rail network
-
-    def reset(self):
-        self.target_positions = { agent.target: 1 for agent in self.env.agents }
-        self.edge_positions = defaultdict(list)  # (cell.position, direction) -> [(start, end, direction, distance)]
-        self.edge_paths = defaultdict(list)  # (node.position, direction) -> [(cell.position, direction)]
-
-        # First, we find a node by starting at one of the agents and following the rails until we reach a junction
-        agent = first(self.env.agents)
-        position = agent.initial_position
-        direction = agent.direction
-        while not self.is_junction(position) and not self.is_target(position):
-            direction = first(self.get_possible_transitions(position, direction))
-            position = get_new_position(position, direction)
-
-        # Now we create a graph representation of the rail network, starting from this node
-        transitions = self.get_all_transitions(position)
-        root_nodes = { t: RailNode(position, t, self.is_target(position)) for t in transitions if t }
-        self.graph = { (*position, d): root_nodes[t] for d, t in enumerate(transitions) if t }
-
-        for transitions, node in root_nodes.items():
-            for direction in transitions:
-                self.explore_branch(node, get_new_position(position, direction), direction)
-
-
-    def explore_branch(self, node, position, direction):
-        original_direction = direction
-        edge_positions = {}
-        distance = 1
-
-        # Explore until we find a junction
-        while not self.is_junction(position) and not self.is_target(position):
-            next_direction = first(self.get_possible_transitions(position, direction))
-            edge_positions[(*position, direction)] = (distance, next_direction)
-            position = get_new_position(position, next_direction)
-            direction = next_direction
-            distance += 1
-
-        # Create any nodes that aren't in the graph yet
-        transitions = self.get_all_transitions(position)
-        nodes = { t: RailNode(position, t, self.is_target(position))
-                  for d, t in enumerate(transitions)
-                  if t and (*position, d) not in self.graph }
-
-        for d, t in enumerate(transitions):
-            if t in nodes:
-                self.graph[(*position, d)] = nodes[t]
-
-        # Connect the previous node to the next one, and update self.edge_positions
-        next_node = self.graph[(*position, direction)]
-        node.edges[original_direction] = (next_node, distance)
-        for key, (distance, next_direction) in edge_positions.items():
-            self.edge_positions[key].append((node, next_node, original_direction, distance))
-            self.edge_paths[node.position, original_direction].append((*key, next_direction))
-
-        # Call ourselves recursively since we're exploring depth-first
-        for transitions, node in nodes.items():
-            for direction in transitions:
-                self.explore_branch(node, get_new_position(position, direction), direction)
-
-
-    # Create a tree observation for each agent, based on the graph we created earlier
-
-    def get_many(self, handles = []):
-        self.nodes_with_agents_going, self.edges_with_agents_going = {}, defaultdict(dict)
-        self.nodes_with_agents_coming, self.edges_with_agents_coming = {}, defaultdict(dict)
-        self.nodes_with_malfunctions, self.edges_with_malfunctions = {}, defaultdict(dict)
-        self.nodes_with_departures, self.edges_with_departures = {}, defaultdict(dict)
-
-        # Create some lookup tables that we can use later to figure out how far away the agents are from each other.
-        for agent in self.env.agents:
-            if agent.status == RailAgentStatus.READY_TO_DEPART and agent.initial_position:
-                for direction in range(4):
-                    if (*agent.initial_position, direction) in self.graph:
-                        self.nodes_with_departures[(*agent.initial_position, direction)] = 1
-
-                    for start, _, start_direction, distance in self.edge_positions[(*agent.initial_position, direction)]:
-                        self.edges_with_departures[(*start.position, start_direction)][agent.handle] = distance
-
-            if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and agent.position:
-                agent_key = (*agent.position, agent.direction)
-                for direction in range(4):
-                    # # Check the nodes
-                    if (*agent.position, direction) in self.graph:
-                        node_dict = self.nodes_with_agents_going if direction == agent.direction else self.nodes_with_agents_coming
-                        node_dict[(*agent.position, direction)] = agent.speed_data['speed']
-
-                        # if len(self.graph[agent_key].edges) > 1:
-                        #     exit_direction = get_direction(agent.direction, agent.speed_data['transition_action_on_cellexit'])
-                        #     if agent.speed_data['position_fraction'] == 0 or exit_direction not in self.graph[agent_key].edges:  # Agent still has options
-                        #         self.nodes_with_agents_going[(*agent.position, direction)] = agent.speed_data['speed']
-                        #     else:  # Agent has already decided
-                        #         coming_direction = (exit_direction + 2) % 4
-                        #         node_dict = self.nodes_with_agents_coming if direction == coming_direction else self.nodes_with_agents_going
-                        #         node_dict[(*agent.position, direction)] = agent.speed_data['speed']
-                        # else:
-                        #     exit_direction = first(self.graph[agent_key].edges.keys())
-                        #     coming_direction = (exit_direction + 2) % 4
-                        #     node_dict = self.nodes_with_agents_coming if direction == coming_direction else self.nodes_with_agents_going
-                        #     node_dict[(*agent.position, direction)] = agent.speed_data['speed']
-
-                    # Check the edges
-                    if agent_key in self.edge_positions:
-                        exit_direction = first(self.get_possible_transitions(agent.position, agent.direction))
-                        coming_direction = (exit_direction + 2) % 4
-                        edge_dict = self.edges_with_agents_coming if direction == coming_direction else self.edges_with_agents_going
-                        if direction == agent.direction or direction == coming_direction:
-                            for start, _, start_direction, distance in self.edge_positions[(*agent.position, direction)]:
-                                edge_distance = distance if direction == agent.direction else start.edges[start_direction][1] - distance
-                                edge_dict[(*start.position, start_direction)][agent.handle] = (distance, agent.speed_data['speed'])
-
-
-                    # Check for malfunctions
-                    if agent.malfunction_data['malfunction']:
-                        if (*agent.position, direction) in self.graph:
-                            self.nodes_with_malfunctions[(*agent.position, direction)] = agent.malfunction_data['malfunction']
-
-                        for start, _, start_direction, distance in self.edge_positions[(*agent.position, direction)]:
-                            self.edges_with_malfunctions[(*start.position, start_direction)][agent.handle] = \
-                                (distance, agent.malfunction_data['malfunction'])
-
-        return super().get_many(handles)
-
-
-    # Compute the observation for a single agent
-    def get(self, handle):
-        agent = self.env.agents[handle]
-        visited_cells = set()
-
-        if agent.status == RailAgentStatus.READY_TO_DEPART:
-            agent_position = agent.initial_position
-        elif agent.status == RailAgentStatus.ACTIVE:
-            agent_position = agent.position
-        elif agent.status == RailAgentStatus.DONE:
-            agent_position = agent.target
-        else: return None
-
-        # The root node contains information about the agent itself
-        children = { x: -np.inf for x in ACTIONS }
-        dist_min_to_target = self.env.distance_map.get()[(handle, *agent_position, agent.direction)]
-        agent_malfunctioning, agent_speed = agent.malfunction_data['malfunction'], agent.speed_data['speed']
-        root_tree_node = Node(0, 0, 0, 0, 0, 0, dist_min_to_target, 0, 0, agent_malfunctioning, agent_speed, 0, children)
-
-        # Then we build out the tree by exploring from this node
-        key = (*agent_position, agent.direction)
-        if key in self.graph:  # If we're sitting on a junction, branch out immediately
-            node = self.graph[key]
-            if len(node.edges) > 1:  # Major node
-                for direction in self.graph[key].edges.keys():
-                    root_tree_node.childs[get_action(agent.direction, direction)] = \
-                        self.get_tree_branch(agent, node, direction, visited_cells, 0, 1)
-            else:  # Minor node
-                direction = first(self.get_possible_transitions(node.position, agent.direction))
-                root_tree_node.childs['F'] = self.get_tree_branch(agent, node, direction, visited_cells, 0, 1)
-
-        else:  # Just create a single child in the forward direction
-            prev_node, next_node, direction, distance = first(self.edge_positions[key])
-            root_tree_node.childs['F'] = self.get_tree_branch(agent, prev_node, direction, visited_cells, -distance, 1)
-
-        self.env.dev_obs_dict[handle] = visited_cells
-
-        return root_tree_node
-
-
-    # Get the next tree node, starting from `node`, facing `orientation`, and moving in `direction`.
-    def get_tree_branch(self, agent, node, direction, visited_cells, total_distance, depth):
-        visited_cells.add((*node.position, 0))
-        next_node, distance = node.edges[direction]
-        original_position = node.position
-
-        targets, agents, minor_nodes = [], [], []
-        edge_length, max_malfunction_length = 0, 0
-        num_agents_same_direction, num_agents_other_direction = 0, 0
-        distance_to_minor_node, distance_to_other_agent = np.inf, np.inf
-        distance_to_own_target, distance_to_other_target = np.inf, np.inf
-        min_agent_speed, num_agent_departures = 1.0, 0
-
-        # Skip ahead until we get to a major node, logging any agents on the tracks along the way
-        while True:
-            path = self.edge_paths.get((node.position, direction), [])
-            orientation = path[-1][-1] if path else direction
-            dist = total_distance + edge_length
-            key = (*node.position, direction)
-            next_key = (*next_node.position, orientation)
-
-            visited_cells.update(path)
-            visited_cells.add((*next_node.position, 0))
-
-            # Check for other agents on the junctions up ahead
-            if next_key in self.nodes_with_agents_going:
-                num_agents_same_direction += 1
-                # distance_to_other_agent = min(distance_to_other_agent, edge_length + distance)
-                min_agent_speed = min(min_agent_speed, self.nodes_with_agents_going[next_key])
-
-            if next_key in self.nodes_with_agents_coming:
-                num_agents_other_direction += 1
-                distance_to_other_agent = min(distance_to_other_agent, edge_length + distance)
-
-            if next_key in self.nodes_with_departures:
-                num_agent_departures += 1
-            if next_key in self.nodes_with_malfunctions:
-                max_malfunction_length = max(max_malfunction_length, self.nodes_with_malfunctions[next_key])
-
-            # Check for other agents along the tracks up ahead
-            for d, s in self.edges_with_agents_going[key].values():
-                if dist + d > 0:
-                    num_agents_same_direction += 1
-                    min_agent_speed = min(min_agent_speed, s)
-                    # distance_to_other_agent = min(distance_to_other_agent, edge_length + d)
-
-            for d, _ in self.edges_with_agents_coming[key].values():
-                if dist + d > 0:
-                    num_agents_other_direction += 1
-                    distance_to_other_agent = min(distance_to_other_agent, edge_length + d)
-
-            for d in self.edges_with_departures[key].values():
-                if dist + d > 0:
-                    num_agent_departures += 1
-
-            for d, t in self.edges_with_malfunctions[key].values():
-                if dist + d > 0:
-                    max_malfunction_length = max(max_malfunction_length, t)
-
-            # Check for target nodes up ahead
-            if next_node.is_target:
-                if self.is_own_target(agent, next_node):
-                    distance_to_own_target = min(distance_to_own_target, edge_length + distance)
-                else: distance_to_other_target = min(distance_to_other_target, edge_length + distance)
-
-            # Move on to the next node
-            node = next_node
-            edge_length += distance
-
-            if len(node.edges) == 1 and not self.is_own_target(agent, node):  # This is a minor node, keep exploring
-                direction, (next_node, distance) = first(node.edges.items())
-                if not node.is_target:
-                    distance_to_minor_node = min(distance_to_minor_node, edge_length)
-            else: break
-
-        # Create a new tree node and populate its children
-        if depth < self.max_depth:
-            children = { x: -np.inf for x in ACTIONS }
-            if not self.is_own_target(agent, node):
-                for direction in node.edges.keys():
-                    children[get_action(orientation, direction)] = \
-                        self.get_tree_branch(agent, node, direction, visited_cells, total_distance + edge_length, depth + 1)
-
-        else: children = {}
-
-        return Node(dist_own_target_encountered=total_distance + distance_to_own_target,
-                    dist_other_target_encountered=total_distance + distance_to_other_target,
-                    dist_other_agent_encountered=total_distance + distance_to_other_agent,
-                    dist_potential_conflict=np.inf,
-                    dist_unusable_switch=total_distance + distance_to_minor_node,
-                    dist_to_next_branch=total_distance + edge_length,
-                    dist_min_to_target=self.env.distance_map.get()[(agent.handle, *node.position, orientation)] or 0,
-                    num_agents_same_direction=num_agents_same_direction,
-                    num_agents_opposite_direction=num_agents_other_direction,
-                    num_agents_malfunctioning=max_malfunction_length,
-                    speed_min_fractional=min_agent_speed,
-                    num_agents_ready_to_depart=num_agent_departures,
-                    childs=children)
-
-
-    # Helper functions
-
-    def get_possible_transitions(self, position, direction):
-        return [i for i, allowed in enumerate(self.env.rail.get_transitions(*position, direction)) if allowed]
-
-    def get_all_transitions(self, position):
-        bit_groups = f'{self.env.rail.get_full_transitions(*position):019_b}'.split("_")
-        return [tuple(i for i, allowed in enumerate(bits) if allowed == '1') for bits in bit_groups]
-
-    def is_junction(self, position):
-        return any(len(transitions) > 1 for transitions in self.get_all_transitions(position))
-
-    def is_target(self, position):
-        return position in self.target_positions
-
-    def is_own_target(self, agent, node):
-        return agent.target == node.position
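
Note on the `chunk()` / threaded `normalize()` helpers added to `src/train.py` above: they rely on the `zip_longest` grouper idiom to split the batch of per-environment observations into fixed-size slices, each of which is handed to one worker via `Pool.starmap` together with its write offset (`i * chunk_size`) into the target tensor. Below is a minimal, self-contained sketch of that idiom; the list contents and printed offsets are illustrative only and not part of the repository.
```python
from itertools import zip_longest


def chunk(obj, size):
    # zip_longest over `size` references to the *same* iterator yields
    # consecutive groups of `size` items; the last group is padded with None.
    return zip_longest(*[iter(obj)] * size, fillvalue=None)


# Illustrative stand-ins: in train.py each item would be one environment's
# observation, and each group would go to one worker process together with
# its offset into the shared target tensor.
observations = list(range(10))
chunk_size = 4
for i, group in enumerate(chunk(observations, chunk_size)):
    print(i * chunk_size, group)
# 0 (0, 1, 2, 3)
# 4 (4, 5, 6, 7)
# 8 (8, 9, None, None)
```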