Changes from all commits (106 commits)
e4cfa1d
Added visualization for validation
Midren Oct 17, 2022
b2d2769
working dqn
Midren Oct 17, 2022
3a358e6
Added VAE implementation
Midren Oct 26, 2022
aeb052d
Added dm_control integration
Midren Nov 2, 2022
21c179f
Added VQ-VAE, tested on CIFAR-10
Midren Nov 4, 2022
0fee658
Added clusterized sampling, and sampling of observations instead of s…
Midren Nov 5, 2022
7907e33
Written skeleton for Dreamer and important future notes
Midren Nov 8, 2022
a820328
Implemented World Model mainloop and loss calculation
Midren Nov 8, 2022
ae20fcb
Added generation of imaginative rollouts
Midren Nov 10, 2022
4c3f249
Implemented ActorCritic training in Latent Space
Midren Nov 10, 2022
a8089b9
Added Env abstraction for different env support, and action transform
Midren Nov 14, 2022
cc038c1
Rewritten replay buffer to be more memory efficient and fixed cluster sa…
Midren Nov 15, 2022
2266ea5
Small improvements & fixes
Midren Nov 15, 2022
24923e9
Added dist probs & planning visualization
Midren Nov 15, 2022
b5ed8b1
Performance improvement
Midren Dec 4, 2022
769e308
Fixes
Midren Dec 5, 2022
5bbc60c
Fixes
Midren Dec 6, 2022
03a0ee7
Various fixes and refactorings
Midren Dec 6, 2022
5932594
Fix reconstruction & improve hist viz
Midren Dec 6, 2022
6bfb34b
Add PersistentReplayBuffer based on WebDataset
Midren Dec 7, 2022
c88cd04
add async iteration
Midren Dec 9, 2022
91e96c4
Improved the performance of training
Midren Dec 18, 2022
3c46b05
Added mock env, which just iterates over different colors
Midren Jan 7, 2023
90bb2a4
Fixed clustering of samples and correct input for world-model components
Midren Jan 6, 2023
faf35b9
Fix training for reconstruction
Midren Jan 13, 2023
224fd52
Fixed discount factor (invert is_finished), add lamda_return tests
Midren Jan 14, 2023
ce91566
Changed discount factor to enforce 0 after first 0, fixed edge cases …
Midren Jan 14, 2023
c4809d7
Minor fixes, which don't fix anything
Midren Jan 14, 2023
eacce41
Added metrics for AC
Midren Jan 14, 2023
6ac194f
WIP
Midren Jan 19, 2023
3c4d94e
Using Tanh/Trunc distributions from torchrl, fixed number of layers i…
Midren Jan 19, 2023
9a43e7f
It is working! (confirmed for cheetah and walker)
Midren Jan 27, 2023
b12adad
Changed income data to account for 1 frame, fixed rewards/discount_fa…
Midren Jan 27, 2023
42b448b
Changed video logging (5 steps to update determ state, show error)
Midren Jan 28, 2023
9d3d00b
Added state abstraction which simplified a lot !
Midren Jan 28, 2023
0dabd54
Fixed video log
Midren Jan 28, 2023
d8ba8c0
Fixes
Midren Jan 29, 2023
1164ba6
Fixed differences with original DreamerV2
Midren Feb 1, 2023
a6a4e44
Fixed critical bug with nullified action in training
Midren Feb 3, 2023
b689a22
added reloading from checkpoint
Midren Feb 17, 2023
0958015
Added VQ-VAE as a step after each RNN step
Midren Feb 18, 2023
7ea44bf
Added parameter for disabling discrete rssm
Midren Feb 18, 2023
d534717
Added Crafter environment
Midren Feb 18, 2023
74c1a3c
Added logging of gradients
Midren Feb 19, 2023
07ee10f
Added missing Dreamer parts, required for Crafter
Midren Feb 20, 2023
b21467a
Added scheduler
Midren Feb 21, 2023
abb8a20
Fixed loss in quantizer
Midren Mar 4, 2023
383ebfa
Fixes
Midren Mar 9, 2023
6366540
Small fixes
Midren Mar 23, 2023
3b9906f
Added reconstruction of DINO features
Midren Mar 23, 2023
0d3a6b0
tuned parameters for dino features, added convolutional vit decoder
Midren Apr 4, 2023
2d2b658
Added upsample before ViT
Midren Apr 4, 2023
0a5ef3a
Added 16 patch size with 224 px variant
Midren Apr 5, 2023
3941234
Fixes to work with torch.compile
Midren Apr 7, 2023
f2c0906
Working slot attention implementation, with 224px resize
Midren Apr 7, 2023
9d2839f
Added DINO features reconstruction for slot attention
Midren Apr 11, 2023
f060b5a
64px encoder, add possibility to use prev_slots
Midren Apr 11, 2023
5034573
Added first implementation of slot attention inside Dreamer
Midren Apr 15, 2023
d99db5f
Renamed
Midren Apr 15, 2023
ad411de
Added dockerfiles with installation guide
Midren Apr 23, 2023
7d2f6b1
Slot attention debug
Midren Apr 23, 2023
e5b48a2
Added parameter sweep configuration to run on each of gpu
Midren Apr 23, 2023
e4b2ec4
Added bigger encoder
Midren Apr 24, 2023
1f66d01
parameters without decoder collapse
Midren Apr 25, 2023
ff83255
Refactored to have separate entities in hydra
Midren Apr 26, 2023
01db417
Refactored to have simultaneously multiple version of world models (d…
Midren Apr 27, 2023
6b1f4b5
Get slotted crafter working with DINO features
Midren Jun 3, 2023
a4d0dad
Added per-slot statistics and all slot dynamics learning
Midren Jun 8, 2023
370cfba
Added infra to precalc observation data inside replay buffer
Midren Jun 20, 2023
26f83a7
Added scheduling for attention
Midren Jun 21, 2023
f8846b6
Added hard/soft mixing and per slot mse mean loss
Midren Jun 21, 2023
c181149
Added corresponding configs
Midren Jun 24, 2023
cc362ae
Rewritten hard mixing and simplified per slot reconstruction loss
Midren Jun 24, 2023
5fcbc56
Performance improvements
Midren Jul 3, 2023
42f04c9
Fixed that layer normalization was not applying
Midren Jul 3, 2023
ef28130
Moved the reward transformation inside dreamer for correct tensorboar…
Midren Jul 3, 2023
60fa7c4
Added wandb integration
Midren Jul 4, 2023
4cac3de
Added Atari support and fixed encode_vit for combined
Midren Jul 9, 2023
bdda685
Added option for choosing 224 vs 64 dino features
Midren Jul 12, 2023
54a31e6
Redo dino decoder
Midren Jul 13, 2023
9d169ba
Changed prev_slots to use same random vector
Midren Jul 15, 2023
c3ececa
Rewritten slot attention to calculate out of loop
Midren Jul 15, 2023
4cfaed3
Fix KL div calculation
Midren Jul 16, 2023
a7ff93b
Added pos encoding for combined slot dreamer
Midren Jul 16, 2023
e955bc7
fixup! Added pos encoding for combined slot dreamer
Midren Jul 16, 2023
f5d6413
Changed pos enc to sin-based
Midren Jul 17, 2023
9849dca
Fixed attention RSSM, added updated determ for more stable learning
Midren Jul 17, 2023
b87c194
Fix determ updated
Midren Jul 18, 2023
0e610b2
Added attention block before image update
Midren Jul 18, 2023
e80af48
Added identity decay for embed attention
Midren Jul 18, 2023
9bf4054
Added vit error visualization
Midren Jul 18, 2023
8c8e172
fixup! Added vit error visualization
Midren Jul 18, 2023
bdd8402
Higher kl loss
Midren Jul 19, 2023
276a872
Increase kl loss and batch size
Midren Jul 20, 2023
706e973
Added score calculation for crafter
Midren Jul 21, 2023
70dcd58
fixing differences with baselines
Midren Jul 23, 2023
aaa5dc4
Changes to the decoder
Midren Jul 25, 2023
4e437f2
Try newer dino-only attention slotted
Midren Jul 25, 2023
5829304
Train with f16, improve slot attention
Midren Aug 1, 2023
384b253
Pytorch 2.0 support, Spatial Broadcast Decoder and fixes
Midren Aug 11, 2023
86f0ba2
Added slot implementation after the dynamics model
Midren Aug 11, 2023
8369018
Added per slot rec loss
Midren Aug 13, 2023
3869755
Fixed per slot rec loss
Midren Aug 13, 2023
ac5998d
Totally fixed per slot rec for post slot
Midren Aug 13, 2023
72449dc
Add config option to tweak spatial decoder
Midren Aug 18, 2023
edad471
No image predictor in dino only, new envs
Midren Sep 2, 2023
3 changes: 3 additions & 0 deletions .gitignore
@@ -1 +1,4 @@
__pycache__/
.vscode/
runs/
poetry.lock
9 changes: 8 additions & 1 deletion .vimspector.json
@@ -39,7 +39,14 @@
"Run main": {
"extends": "python-base",
"configuration": {
"program": "main.py",
"program": "rl_sandbox/train.py",
"args": ["logger.type='tensorboard'", "training.prefill=0", "training.batch_size=4"]
}
},
"Run dino": {
"extends": "python-base",
"configuration": {
"program": "rl_sandbox/vision/slot_attention.py",
"args": []
}
}
61 changes: 61 additions & 0 deletions Dockerfile
@@ -0,0 +1,61 @@
ARG BASE_IMAGE=nvidia/cudagl:11.3.0-devel
FROM $BASE_IMAGE

ARG USER_ID
ARG GROUP_ID
ARG USER_NAME=user

RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y ssh gcc g++ gdb clang rsync tar python sudo git ffmpeg ninja-build locales \
&& apt-get clean \
&& sudo rm -rf /var/lib/apt/lists/*

RUN ( \
echo 'LogLevel DEBUG2'; \
echo 'PermitRootLogin yes'; \
echo 'PasswordAuthentication yes'; \
echo 'Subsystem sftp /usr/lib/openssh/sftp-server'; \
) > /etc/ssh/sshd_config_test_clion \
&& mkdir /run/sshd

RUN groupadd -g ${GROUP_ID} ${USER_NAME} && \
useradd -u ${USER_ID} -g ${GROUP_ID} -s /bin/bash -m ${USER_NAME} && \
yes password | passwd ${USER_NAME} && \
usermod -aG sudo ${USER_NAME} && \
echo "${USER_NAME} ALL=(ALL) NOPASSWD:ALL" | sudo tee /etc/sudoers.d/user && \
chmod 440 /etc/sudoers

USER ${USER_NAME}

RUN git clone https://github.com/Midren/dotfiles /home/${USER_NAME}/.dotfiles && \
/home/${USER_NAME}/.dotfiles/install-profile ubuntu-cli

RUN git config --global user.email "milromchuk@gmail.com" && \
git config --global user.name "Roman Milishchuk"

USER root

RUN apt-get update \
&& apt-get install -y software-properties-common curl \
&& add-apt-repository -y ppa:deadsnakes/ppa \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y python3.10 python3.10-dev python3.10-venv \
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 \
&& apt-get clean \
&& sudo rm -rf /var/lib/apt/lists/*

RUN sudo update-alternatives --install /usr/bin/python3 python /usr/bin/python3.10 1 \
&& sudo update-alternatives --install /usr/bin/python python3 /usr/bin/python3.10 1

USER ${USER_NAME}
WORKDIR /home/${USER_NAME}/

RUN mkdir /home/${USER_NAME}/rl_sandbox

COPY pyproject.toml /home/${USER_NAME}/rl_sandbox/pyproject.toml
COPY rl_sandbox /home/${USER_NAME}/rl_sandbox/rl_sandbox

RUN cd /home/${USER_NAME}/rl_sandbox \
&& python3.10 -m pip install --no-cache-dir -e . \
&& rm -Rf /home/${USER_NAME}/.cache/pip


23 changes: 23 additions & 0 deletions README.md
@@ -0,0 +1,23 @@
## RL sandbox

## Run

Build docker:
```sh
docker build --build-arg USER_ID=$(id -u) --build-arg GROUP_ID=$(id -g) --build-arg USER_NAME=$USER -t dreamer .
```

Run docker with tty:
```sh
docker run --gpus 'all' -it --rm -v `pwd`:/home/$USER/rl_sandbox -w /home/$USER/rl_sandbox dreamer zsh
```

Run training inside docker on gpu 0:
```sh
docker run --gpus 'device=0' -it --rm -v `pwd`:/home/$USER/rl_sandbox -w /home/$USER/rl_sandbox dreamer python3 rl_sandbox/train.py --config-name config_dino
```

To run the Dreamer version with slot attention, use:
```sh
rl_sandbox/train.py --config-name config_slotted
```
4 changes: 0 additions & 4 deletions config/agent/dqn_agent.yaml

This file was deleted.

12 changes: 0 additions & 12 deletions config/config.yaml

This file was deleted.

49 changes: 0 additions & 49 deletions main.py

This file was deleted.

31 changes: 30 additions & 1 deletion pyproject.toml
@@ -8,9 +8,38 @@ version = "0.1.0"
description = 'Sandbox for my RL experiments'
authors = ['Roman Milishchuk <milishchuk.roman@gmail.com>']
packages = [{include = 'rl_sandbox'}]
# add config directory as package data

# TODO: add yapf and isort as development dependencies
[tool.poetry.dependencies]
python = "^3.10"
numpy = '*'
nptyping = '*'
gym = "^0.26.1"
gym = "0.25.0" # crafter requires old step api
pygame = '*'
moviepy = '*'
torchvision = '*'
torch = '^2.0'
tensorboard = '^2.0'
dm-control = '^1.0.0'
unpackable = '^0.0.4'
hydra-core = "^1.2.0"
matplotlib = "^3.0.0"
webdataset = "^0.2.20"
jaxtyping = '^0.2.0'
lovely_tensors = '^0.1.10'
torchshow = '^0.5.0'
crafter = '^1.8.0'
wandb = '*'
flatten-dict = '*'
hydra-joblib-launcher = "*"

[tool.yapf]
based_on_style = "pep8"
column_limit = 90

[tool.pytest.ini_options]
addopts = [
"--import-mode=importlib",
]

2 changes: 2 additions & 0 deletions rl_sandbox/agents/__init__.py
@@ -0,0 +1,2 @@
from rl_sandbox.agents.dqn import DqnAgent
from rl_sandbox.agents.dreamer_v2 import DreamerV2
31 changes: 18 additions & 13 deletions rl_sandbox/agents/dqn_agent.py → rl_sandbox/agents/dqn.py
@@ -4,44 +4,49 @@
from rl_sandbox.agents.rl_agent import RlAgent
from rl_sandbox.utils.fc_nn import fc_nn_generator
from rl_sandbox.utils.replay_buffer import (Action, Actions, Rewards, State,
States, TerminationFlag)
States, TerminationFlags)


class DqnAgent(RlAgent):
def __init__(self, actions_num: int,
obs_space_num: int,
hidden_layer_size: int,
num_layers: int,
discount_factor: float):
discount_factor: float,
device_type: str = 'cpu'):
self.gamma = discount_factor
self.value_func = fc_nn_generator(obs_space_num,
actions_num,
hidden_layer_size,
num_layers)
num_layers,
torch.nn.ReLU).to(device_type)
self.optimizer = torch.optim.Adam(self.value_func.parameters(), lr=1e-3)
self.loss = torch.nn.MSELoss()
self.device_type = device_type

def get_action(self, obs: State) -> Action:
return np.array(torch.argmax(self.value_func(torch.from_numpy(obs)), dim=1))
return np.array(torch.argmax(self.value_func(torch.from_numpy(obs.reshape(1, -1)).to(self.device_type)), dim=1).detach().cpu())[0]

def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlag):
def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags):
# Bellman error: MSE( (r + gamma * max_a Q(S_t+1, a)) - Q(s_t, a) )
# check for is finished

s = torch.from_numpy(s)
a = torch.from_numpy(a)
r = torch.from_numpy(r)
next = torch.from_numpy(next)
is_finished = torch.from_numpy(is_finished)
s = torch.from_numpy(s).to(self.device_type)
a = torch.from_numpy(a).to(self.device_type)
r = torch.from_numpy(r).to(self.device_type)
next = torch.from_numpy(next).to(self.device_type)
is_finished = torch.from_numpy(is_finished).to(self.device_type)

# TODO: normalize input
# TODO: double dqn with target network
values = self.value_func(next)
indeces = torch.argmax(values, dim=1)
x = r + (self.gamma * torch.gather(values, dim=1, index=indeces.unsqueeze(1)).squeeze(1)) * torch.logical_not(is_finished)
target = r + (self.gamma * torch.gather(values, dim=1, index=indeces.unsqueeze(1)).squeeze(1)) * torch.logical_not(is_finished)

loss = self.loss(x, torch.gather(self.value_func(s), dim=1, index=a).squeeze(1))
loss = self.loss(torch.gather(self.value_func(s), dim=1, index=a).squeeze(1), target.detach())

self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

return loss.detach()
return {'loss': loss.detach().cpu()}
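
For reference, a minimal usage sketch of the updated `DqnAgent` interface (not part of this PR): only the constructor and method signatures are taken from the diff above; the CartPole environment, layer sizes, and single-transition training loop are illustrative assumptions — a real run would use a replay buffer.

```python
# Sketch only: exercises the DqnAgent API shown in the diff above.
# Environment choice and hyperparameters are hypothetical.
import gym
import numpy as np

from rl_sandbox.agents.dqn import DqnAgent

env = gym.make("CartPole-v1")
agent = DqnAgent(actions_num=env.action_space.n,
                 obs_space_num=env.observation_space.shape[0],
                 hidden_layer_size=64,
                 num_layers=2,
                 discount_factor=0.99,
                 device_type='cpu')

obs = env.reset()
for _ in range(1000):
    action = agent.get_action(obs)
    # gym is pinned to 0.25.0, so step() returns the old 4-tuple API
    next_obs, reward, done, info = env.step(int(action))
    # train() expects batched numpy arrays, matching the signature in the diff
    metrics = agent.train(obs[None].astype(np.float32),
                          np.array([[action]], dtype=np.int64),
                          np.array([reward], dtype=np.float32),
                          next_obs[None].astype(np.float32),
                          np.array([done]))
    obs = env.reset() if done else next_obs
```

Note that `train()` now returns a metrics dict (`metrics['loss']`) instead of a bare tensor, and all tensors are moved to `device_type` internally.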
1 change: 1 addition & 0 deletions rl_sandbox/agents/dreamer/__init__.py
@@ -0,0 +1 @@
from .common import *