diff --git a/.gitignore b/.gitignore index c18dd8d..7a5b954 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ __pycache__/ +.vscode/ +runs/ +poetry.lock diff --git a/.vimspector.json b/.vimspector.json index 5138480..5bedd5f 100644 --- a/.vimspector.json +++ b/.vimspector.json @@ -39,7 +39,14 @@ "Run main": { "extends": "python-base", "configuration": { - "program": "main.py", + "program": "rl_sandbox/train.py", + "args": ["logger.type='tensorboard'", "training.prefill=0", "training.batch_size=4"] + } + }, + "Run dino": { + "extends": "python-base", + "configuration": { + "program": "rl_sandbox/vision/slot_attention.py", "args": [] } } diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..cd179cf --- /dev/null +++ b/Dockerfile @@ -0,0 +1,61 @@ +ARG BASE_IMAGE=nvidia/cudagl:11.3.0-devel +FROM $BASE_IMAGE + +ARG USER_ID +ARG GROUP_ID +ARG USER_NAME=user + +RUN apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y ssh gcc g++ gdb clang rsync tar python sudo git ffmpeg ninja-build locales \ + && apt-get clean \ + && sudo rm -rf /var/lib/apt/lists/* + +RUN ( \ + echo 'LogLevel DEBUG2'; \ + echo 'PermitRootLogin yes'; \ + echo 'PasswordAuthentication yes'; \ + echo 'Subsystem sftp /usr/lib/openssh/sftp-server'; \ + ) > /etc/ssh/sshd_config_test_clion \ + && mkdir /run/sshd + +RUN groupadd -g ${GROUP_ID} ${USER_NAME} && \ + useradd -u ${USER_ID} -g ${GROUP_ID} -s /bin/bash -m ${USER_NAME} && \ + yes password | passwd ${USER_NAME} && \ + usermod -aG sudo ${USER_NAME} && \ + echo "${USER_NAME} ALL=(ALL) NOPASSWD:ALL" | sudo tee /etc/sudoers.d/user && \ + chmod 440 /etc/sudoers + +USER ${USER_NAME} + +RUN git clone https://github.com/Midren/dotfiles /home/${USER_NAME}/.dotfiles && \ + /home/${USER_NAME}/.dotfiles/install-profile ubuntu-cli + +RUN git config --global user.email "milromchuk@gmail.com" && \ + git config --global user.name "Roman Milishchuk" + +USER root + +RUN apt-get update \ + && apt-get install -y software-properties-common curl \ + && add-apt-repository -y ppa:deadsnakes/ppa \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y python3.10 python3.10-dev python3.10-venv \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 \ + && apt-get clean \ + && sudo rm -rf /var/lib/apt/lists/* + +RUN sudo update-alternatives --install /usr/bin/python3 python /usr/bin/python3.10 1 \ + && sudo update-alternatives --install /usr/bin/python python3 /usr/bin/python3.10 1 + +USER ${USER_NAME} +WORKDIR /home/${USER_NAME}/ + +RUN mkdir /home/${USER_NAME}/rl_sandbox + +COPY pyproject.toml /home/${USER_NAME}/rl_sandbox/pyproject.toml +COPY rl_sandbox /home/${USER_NAME}/rl_sandbox/rl_sandbox + +RUN cd /home/${USER_NAME}/rl_sandbox \ + && python3.10 -m pip install --no-cache-dir -e . \ + && rm -Rf /home/${USER_NAME}/.cache/pip + + diff --git a/README.md b/README.md new file mode 100644 index 0000000..0f2851c --- /dev/null +++ b/README.md @@ -0,0 +1,23 @@ +## RL sandbox + +## Run + +Build docker: +```sh +docker build --build-arg USER_ID=$(id -u) --build-arg GROUP_ID=$(id -g) --build-arg USER_NAME=$USER -t dreamer . 
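+# (note, assumption about intent) USER_ID / GROUP_ID make the in-container user
+# match the host user, so files created in the bind-mounted repository keep the
+# right ownership; "dreamer" is only a local image tag and can be renamed.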
+``` + +Run docker with tty: +```sh +docker run --gpus 'all' -it --rm -v `pwd`:/home/$USER/rl_sandbox -w /home/$USER/rl_sandbox dreamer zsh +``` + +Run training inside docker on gpu 0: +```sh +docker run --gpus 'device=0' -it --rm -v `pwd`:/home/$USER/rl_sandbox -w /home/$USER/rl_sandbox dreamer python3 rl_sandbox/train.py --config-name config_dino +``` + +To run dreamer version with slot attention use: +``` +rl_sandbox/train.py --config-name config_slotted +``` diff --git a/config/agent/dqn_agent.yaml b/config/agent/dqn_agent.yaml deleted file mode 100644 index 327bf96..0000000 --- a/config/agent/dqn_agent.yaml +++ /dev/null @@ -1,4 +0,0 @@ -name: dqn -hidden_layer_size: 32 -num_layers: 2 -discount_factor: 0.99 diff --git a/config/config.yaml b/config/config.yaml deleted file mode 100644 index 4f14d23..0000000 --- a/config/config.yaml +++ /dev/null @@ -1,12 +0,0 @@ -defaults: - - agent/dqn_agent - -env: CartPole-v1 -seed: 42 - -training: - epochs: 5000 - batch_size: 128 - -validation: - rollout_num: 5 diff --git a/main.py b/main.py deleted file mode 100644 index 0936874..0000000 --- a/main.py +++ /dev/null @@ -1,49 +0,0 @@ -import hydra -from omegaconf import DictConfig, OmegaConf - -from rl_sandbox.agents.dqn_agent import DqnAgent -from rl_sandbox.utils.replay_buffer import ReplayBuffer -from rl_sandbox.utils.rollout_generation import collect_rollout, fillup_replay_buffer, collect_rollout_num - -from torch.utils.tensorboard.writer import SummaryWriter -import numpy as np - -import gym - -@hydra.main(version_base="1.2", config_path='config', config_name='config') -def main(cfg: DictConfig): - print(OmegaConf.to_yaml(cfg)) - - env = gym.make(cfg.env) - - buff = ReplayBuffer() - # FIXME: samples should be also added afterwards - fillup_replay_buffer(env, buff, cfg.training.batch_size) - - # INFO: currently supports only discrete action space - agent_params = {**cfg.agent} - agent_name = agent_params.pop('name') - agent = DqnAgent(obs_space_num=env.observation_space.shape[0], - actions_num=env.action_space.n, - **agent_params, - ) - - writer = SummaryWriter() - - for epoch_num in range(cfg.training.epochs): - # TODO: add exploration and adding data to buffer at each step - - s, a, r, n, f = buff.sample(cfg.training.batch_size) - - loss = agent.train(s, a, r, n, f) - writer.add_scalar('train/loss', loss, epoch_num) - - if epoch_num % 100 == 0: - rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) - average_len = np.mean(list(map(lambda x: len(x[0]), rollouts))) - writer.add_scalar('val/average_len', average_len, epoch_num) - - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index e986927..bd601c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,9 +8,38 @@ version = "0.1.0" description = 'Sandbox for my RL experiments' authors = ['Roman Milishchuk '] packages = [{include = 'rl_sandbox'}] +# add config directory as package data +# TODO: add yapf and isort as development dependencies [tool.poetry.dependencies] python = "^3.10" numpy = '*' nptyping = '*' -gym = "^0.26.1" +gym = "0.25.0" # crafter requires old step api +pygame = '*' +moviepy = '*' +torchvision = '*' +torch = '^2.0' +tensorboard = '^2.0' +dm-control = '^1.0.0' +unpackable = '^0.0.4' +hydra-core = "^1.2.0" +matplotlib = "^3.0.0" +webdataset = "^0.2.20" +jaxtyping = '^0.2.0' +lovely_tensors = '^0.1.10' +torchshow = '^0.5.0' +crafter = '^1.8.0' +wandb = '*' +flatten-dict = '*' +hydra-joblib-launcher = "*" + +[tool.yapf] +based_on_style = "pep8" +column_limit = 90 + 
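+# NOTE: yapf/isort are formatting-only and are not yet listed as dev-dependencies
+# (see the TODO above); the pytest options below switch test imports to importlib
+# mode, which avoids sys.path manipulation during collection.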
+[tool.pytest.ini_options] +addopts = [ + "--import-mode=importlib", +] + diff --git a/rl_sandbox/agents/__init__.py b/rl_sandbox/agents/__init__.py new file mode 100644 index 0000000..d1b779f --- /dev/null +++ b/rl_sandbox/agents/__init__.py @@ -0,0 +1,2 @@ +from rl_sandbox.agents.dqn import DqnAgent +from rl_sandbox.agents.dreamer_v2 import DreamerV2 diff --git a/rl_sandbox/agents/dqn_agent.py b/rl_sandbox/agents/dqn.py similarity index 54% rename from rl_sandbox/agents/dqn_agent.py rename to rl_sandbox/agents/dqn.py index 13ba74d..b2fb3b0 100644 --- a/rl_sandbox/agents/dqn_agent.py +++ b/rl_sandbox/agents/dqn.py @@ -4,7 +4,7 @@ from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.fc_nn import fc_nn_generator from rl_sandbox.utils.replay_buffer import (Action, Actions, Rewards, State, - States, TerminationFlag) + States, TerminationFlags) class DqnAgent(RlAgent): @@ -12,36 +12,41 @@ def __init__(self, actions_num: int, obs_space_num: int, hidden_layer_size: int, num_layers: int, - discount_factor: float): + discount_factor: float, + device_type: str = 'cpu'): self.gamma = discount_factor self.value_func = fc_nn_generator(obs_space_num, actions_num, hidden_layer_size, - num_layers) + num_layers, + torch.nn.ReLU).to(device_type) self.optimizer = torch.optim.Adam(self.value_func.parameters(), lr=1e-3) self.loss = torch.nn.MSELoss() + self.device_type = device_type def get_action(self, obs: State) -> Action: - return np.array(torch.argmax(self.value_func(torch.from_numpy(obs)), dim=1)) + return np.array(torch.argmax(self.value_func(torch.from_numpy(obs.reshape(1, -1)).to(self.device_type)), dim=1).detach().cpu())[0] - def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlag): + def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): # Bellman error: MSE( (r + gamma * max_a Q(S_t+1, a)) - Q(s_t, a) ) # check for is finished - s = torch.from_numpy(s) - a = torch.from_numpy(a) - r = torch.from_numpy(r) - next = torch.from_numpy(next) - is_finished = torch.from_numpy(is_finished) + s = torch.from_numpy(s).to(self.device_type) + a = torch.from_numpy(a).to(self.device_type) + r = torch.from_numpy(r).to(self.device_type) + next = torch.from_numpy(next).to(self.device_type) + is_finished = torch.from_numpy(is_finished).to(self.device_type) + # TODO: normalize input + # TODO: double dqn with target network values = self.value_func(next) indeces = torch.argmax(values, dim=1) - x = r + (self.gamma * torch.gather(values, dim=1, index=indeces.unsqueeze(1)).squeeze(1)) * torch.logical_not(is_finished) + target = r + (self.gamma * torch.gather(values, dim=1, index=indeces.unsqueeze(1)).squeeze(1)) * torch.logical_not(is_finished) - loss = self.loss(x, torch.gather(self.value_func(s), dim=1, index=a).squeeze(1)) + loss = self.loss(torch.gather(self.value_func(s), dim=1, index=a).squeeze(1), target.detach()) self.optimizer.zero_grad() loss.backward() self.optimizer.step() - return loss.detach() + return {'loss': loss.detach().cpu()} diff --git a/rl_sandbox/agents/dreamer/__init__.py b/rl_sandbox/agents/dreamer/__init__.py new file mode 100644 index 0000000..55e5f84 --- /dev/null +++ b/rl_sandbox/agents/dreamer/__init__.py @@ -0,0 +1 @@ +from .common import * diff --git a/rl_sandbox/agents/dreamer/ac.py b/rl_sandbox/agents/dreamer/ac.py new file mode 100644 index 0000000..bfd0a12 --- /dev/null +++ b/rl_sandbox/agents/dreamer/ac.py @@ -0,0 +1,146 @@ +import typing as t + +import torch +import torch.distributions as td 
+from torch import nn + +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.utils.fc_nn import fc_nn_generator + + +class ImaginativeCritic(nn.Module): + + def __init__(self, discount_factor: float, update_interval: int, + soft_update_fraction: float, value_target_lambda: float, latent_dim: int, + layer_norm: bool): + super().__init__() + self.gamma = discount_factor + self.critic_update_interval = update_interval + self.lambda_ = value_target_lambda + self.critic_soft_update_fraction = soft_update_fraction + self._update_num = 0 + + self.critic = fc_nn_generator(latent_dim, + 1, + 400, + 5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.target_critic = fc_nn_generator(latent_dim, + 1, + 400, + 5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.target_critic.requires_grad_(False) + + def update_target(self): + if self._update_num == 0: + self.target_critic.load_state_dict(self.critic.state_dict()) + # for target_param, local_param in zip(self.target_critic.parameters(), + # self.critic.parameters()): + # mix = self.critic_soft_update_fraction + # target_param.data.copy_(mix * local_param.data + + # (1 - mix) * target_param.data) + self._update_num = (self._update_num + 1) % self.critic_update_interval + + def estimate_value(self, z) -> td.Distribution: + return self.critic(z) + + def _lambda_return(self, vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor): + # Formula is actually slightly different than in paper + # https://github.com/danijar/dreamerv2/issues/25 + v_lambdas = [vs[-1]] + for i in range(rs.shape[0] - 1, -1, -1): + v_lambda = rs[i] + ds[i] * ( + (1 - self.lambda_) * vs[i + 1] + self.lambda_ * v_lambdas[-1]) + v_lambdas.append(v_lambda) + + reversed_indices = torch.arange(len(v_lambdas)-1, -1, -1) + return torch.stack(v_lambdas)[reversed_indices][:-1] + + def lambda_return(self, zs, rs, ds): + vs = self.target_critic(zs).mode + return self._lambda_return(vs, rs, ds) + + def calculate_loss(self, zs: torch.Tensor, vs: torch.Tensor, + discount_factors: torch.Tensor): + predicted_vs_dist = self.estimate_value(zs.detach()) + losses = { + 'loss_critic': + -(predicted_vs_dist.log_prob(vs.detach()).unsqueeze(2) * + discount_factors).mean() + } + metrics = { + 'critic/avg_target_value': self.target_critic(zs).mode.mean(), + 'critic/avg_lambda_value': vs.mean(), + 'critic/avg_predicted_value': predicted_vs_dist.mode.mean() + } + return losses, metrics + + +class ImaginativeActor(nn.Module): + + def __init__(self, latent_dim: int, actions_num: int, is_discrete: bool, + layer_norm: bool, reinforce_fraction: t.Optional[float], + entropy_scale: float): + super().__init__() + self.rho = reinforce_fraction + if self.rho is None: + self.rho = is_discrete + self.eta = entropy_scale + self.actor = fc_nn_generator( + latent_dim, + actions_num if is_discrete else actions_num * 2, + 400, + 5, + layer_norm=layer_norm, + intermediate_activation=nn.ELU, + final_activation=DistLayer('onehot' if is_discrete else 'normal_trunc')) + + def forward(self, z: torch.Tensor) -> td.Distribution: + return self.actor(z) + + def get_action(self, state) -> td.Distribution: + # FIXME: you should be ashamed for such fix for prev_slots + if isinstance(state, tuple): + return self.actor(state[0].combined) + else: + return self.actor(state.combined) + + def calculate_loss(self, zs: torch.Tensor, vs: torch.Tensor, baseline: torch.Tensor, + discount_factors: torch.Tensor, actions: torch.Tensor): + 
losses = {} + metrics = {} + action_dists = self.actor(zs.detach()) + advantage = (vs - baseline).detach() + losses['loss_actor_reinforce'] = -(self.rho * action_dists.log_prob( + actions.detach()).unsqueeze(2) * discount_factors * advantage).mean() + if self.rho != 1.0: + losses['loss_actor_dynamics_backprop'] = -((1 - self.rho) * + (vs * discount_factors)).mean() + else: + losses['loss_actor_dynamics_backprop'] = torch.tensor(0) + + def calculate_entropy(dist): + # return dist.base_dist.base_dist.entropy().unsqueeze(2) + return dist.entropy().unsqueeze(2) + + losses['loss_actor_entropy'] = -(self.eta * calculate_entropy(action_dists) * + discount_factors).mean() + losses['loss_actor'] = losses['loss_actor_reinforce'] + losses[ + 'loss_actor_dynamics_backprop'] + losses['loss_actor_entropy'] + + # mean and std are estimated statistically as tanh transformation is used + sample = action_dists.rsample((128,)) + act_avg = sample.mean(0) + metrics['actor/avg_val'] = act_avg.mean() + # metrics['actor/mode_val'] = action_dists.mode.mean() + metrics['actor/mean_val'] = action_dists.mean.mean() + metrics['actor/avg_sd'] = (((sample - act_avg)**2).mean(0).sqrt()).mean() + metrics['actor/min_val'] = sample.min() + metrics['actor/max_val'] = sample.max() + + return losses, metrics diff --git a/rl_sandbox/agents/dreamer/common.py b/rl_sandbox/agents/dreamer/common.py new file mode 100644 index 0000000..1c22064 --- /dev/null +++ b/rl_sandbox/agents/dreamer/common.py @@ -0,0 +1,83 @@ +import torch +from torch import nn +import torch.distributions as td +import numpy as np + +from rl_sandbox.utils.dists import DistLayer + +def get_position_encoding(seq_len, d, n=10000): + P = np.zeros((seq_len, d)) + for k in range(seq_len): + for i in np.arange(int(d/2)): + denominator = np.power(n, 2*i/d) + P[k, 2*i] = np.sin(k/denominator) + P[k, 2*i+1] = np.cos(k/denominator) + return P + +class View(nn.Module): + + def __init__(self, shape): + super().__init__() + self.shape = shape + + def forward(self, x): + return x.view(*self.shape) + + +def Dist(val): + return td.Independent(DistLayer('onehot')(val), 1) + + +class Normalizer(nn.Module): + + def __init__(self, momentum=0.99, scale=1.0, eps=1e-8): + super().__init__() + self.momentum = momentum + self.scale = scale + self.eps = eps + self.register_buffer('mag', torch.ones(1, dtype=torch.float32)) + self.mag.requires_grad = False + + def forward(self, x): + self.update(x) + return (x / (self.mag + self.eps)) * self.scale + + def update(self, x): + self.mag = self.momentum * self.mag + (1 - + self.momentum) * (x.abs().mean()).detach() + + +class GRUCell(nn.Module): + + def __init__(self, input_size, hidden_size, norm=False, update_bias=-1, **kwargs): + super().__init__() + self._size = hidden_size + self._act = torch.tanh + self._norm = norm + self._update_bias = update_bias + self._layer = nn.Linear(input_size + hidden_size, + 3 * hidden_size, + bias=norm is not None, + **kwargs) + if norm: + self._norm = nn.LayerNorm(3 * hidden_size) + + @property + def state_size(self): + return self._size + + def forward(self, x, h): + state = h + parts = self._layer(torch.concat([x, state], -1)) + if self._norm: + dtype = parts.dtype + parts = self._norm(parts.float()) + parts = parts.to(dtype=dtype) + reset, cand, update = parts.chunk(3, dim=-1) + reset = torch.sigmoid(reset) + cand = self._act(reset * cand) + update = torch.sigmoid(update + self._update_bias) + output = update * cand + (1 - update) * state + return output, output + + diff --git 
a/rl_sandbox/agents/dreamer/rssm.py b/rl_sandbox/agents/dreamer/rssm.py new file mode 100644 index 0000000..ce8746e --- /dev/null +++ b/rl_sandbox/agents/dreamer/rssm.py @@ -0,0 +1,209 @@ +import typing as t +from dataclasses import dataclass + +import torch +from jaxtyping import Bool, Float +from torch import nn +from torch.nn import functional as F + +from rl_sandbox.agents.dreamer import Dist, View, GRUCell +from rl_sandbox.utils.schedulers import LinearScheduler + +@dataclass +class State: + determ: Float[torch.Tensor, 'seq batch determ'] + stoch_logits: Float[torch.Tensor, 'seq batch latent_classes latent_dim'] + stoch_: t.Optional[Bool[torch.Tensor, 'seq batch stoch_dim']] = None + + def flatten(self): + return State(self.determ.flatten(0, 1).unsqueeze(0), + self.stoch_logits.flatten(0, 1).unsqueeze(0), + self.stoch_.flatten(0, 1).unsqueeze(0) if self.stoch_ is not None else None) + + + def detach(self): + return State(self.determ.detach(), + self.stoch_logits.detach(), + self.stoch_.detach() if self.stoch_ is not None else None) + + @property + def combined(self): + return torch.concat([self.determ, self.stoch], dim=-1) + + @property + def stoch(self): + if self.stoch_ is None: + self.stoch_ = Dist(self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:2] + (-1,)) + return self.stoch_ + + @property + def stoch_dist(self): + return Dist(self.stoch_logits) + + @classmethod + def stack(cls, states: list['State'], dim = 0): + if states[0].stoch_ is not None: + stochs = torch.cat([state.stoch for state in states], dim=dim) + else: + stochs = None + return State(torch.cat([state.determ for state in states], dim=dim), + torch.cat([state.stoch_logits for state in states], dim=dim), + stochs) + + +class Quantize(nn.Module): + + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): + super().__init__() + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + + embed = torch.randn(dim, n_embed) + self.inp_in = nn.Linear(1024, self.n_embed * self.dim) + self.inp_out = nn.Linear(self.n_embed * self.dim, 1024) + self.register_buffer("embed", embed) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", embed.clone()) + + def forward(self, inp): + # input = self.inp_in(inp).reshape(-1, 1, self.n_embed, self.dim) + input = inp.reshape(-1, 1, self.n_embed, self.dim) + inp = input + flatten = input.reshape(-1, self.dim) + dist = (flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ self.embed + + self.embed.pow(2).sum(0, keepdim=True)) + _, embed_ind = (-dist).max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = self.embed_code(embed_ind) + + if self.training: + embed_onehot_sum = embed_onehot.sum(0) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + + self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, + alpha=1 - self.decay) + self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) + n = self.cluster_size.sum() + cluster_size = ((self.cluster_size + self.eps) / + (n + self.n_embed * self.eps) * n) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + self.embed.data.copy_(embed_normalized) + + # quantize_out = self.inp_out(quantize.reshape(-1, self.n_embed*self.dim)) + quantize_out = quantize + diff = 0.25 * (quantize_out.detach() - inp).pow(2).mean() + ( + quantize_out - inp.detach()).pow(2).mean() + quantize = inp + (quantize_out - inp).detach() + + return quantize, diff, embed_ind + + def 
embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.transpose(0, 1)) + +class RSSM(nn.Module): + """ + Recurrent State Space Model + h_t <- deterministic state which is updated inside GRU + s^_t <- stohastic discrete prior state (used for KL divergence: + better predict future and encode smarter) + s_t <- stohastic discrete posterior state (latent representation of current state) + h_1 ---> h_2 ---> h_3 ---> + \\ x_1 \\ x_2 \\ x_3 + | \\ | ^ | \\ | ^ | \\ | ^ + v MLP CNN | v MLP CNN | v MLP CNN | + \\ | | \\ | | \\ | | + Ensemble \\ | | Ensemble \\ | | Ensemble \\ | | + \\| | \\| | \\| | + | | | | | | | | | + v v | v v | v v | + | | | + s^_1 s_1 ---| s^_2 s_2 ---| s^_3 s_3 ---| + """ + + def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discrete_rssm, norm_layer: nn.LayerNorm | nn.Identity): + super().__init__() + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.ensemble_num = 1 + self.hidden_size = hidden_size + self.discrete_rssm = discrete_rssm + + # Calculate deterministic state from prev stochastic, prev action and prev deterministic + self.pre_determ_recurrent = nn.Sequential( + nn.Linear(latent_dim * latent_classes + actions_num, + hidden_size), # Dreamer 'img_in' + norm_layer(hidden_size), + nn.ELU(inplace=True) + ) + self.determ_recurrent = GRUCell(input_size=hidden_size, hidden_size=hidden_size, norm=True) # Dreamer gru '_cell' + + # Calculate stochastic state from prior embed + # shared between all ensemble models + self.ensemble_prior_estimator = nn.Sequential( + nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' + View((1, -1, latent_dim, self.latent_classes))) + + # For observation we do not have ensemble + # FIXME: very bad magic number + img_sz = 4 * 384 # 384*2x2 + self.stoch_net = nn.Sequential( + # nn.LayerNorm(hidden_size + img_sz, hidden_size), + nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'obs_dist' + View((1, -1, latent_dim, self.latent_classes))) + # self.determ_discretizer = MlpVAE(self.hidden_size) + self.determ_discretizer = Quantize(16, 16) + self.discretizer_scheduler = LinearScheduler(1.0, 0.0, 1_000_000) + self.determ_layer_norm = nn.LayerNorm(hidden_size) + + def estimate_stochastic_latent(self, prev_determ: torch.Tensor): + return self.ensemble_prior_estimator(prev_determ) + + def on_train_step(self): + pass + + def predict_next(self, + prev_state: State, + action) -> State: + x = self.pre_determ_recurrent(torch.concat([prev_state.stoch, action], dim=-1)) + # NOTE: x and determ are actually the same value if sequence of 1 is inserted + x, determ_prior = self.determ_recurrent(x, prev_state.determ) + if self.discrete_rssm: + determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) + determ_post = determ_post.reshape(determ_prior.shape) + determ_post = self.determ_layer_norm(determ_post) + alpha = self.discretizer_scheduler.val + determ_post = alpha * determ_prior + (1-alpha) * determ_post + else: + determ_post, diff = determ_prior, 0 + + # used for KL divergence + predicted_stoch_logits = self.estimate_stochastic_latent(x) + return State(determ_post, predicted_stoch_logits), diff + + def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' + return State(prior.determ, 
self.stoch_net(torch.concat([prior.determ, embed], dim=-1))) + + def forward(self, h_prev: State, embed, + action) -> tuple[State, State]: + """ + 'h' <- internal state of the world + 'z' <- latent embedding of current observation + 'a' <- action taken on prev step + Returns 'h_next' <- the next next of the world + """ + prior, diff = self.predict_next(h_prev, action) + posterior = self.update_current(prior, embed) + + return prior, posterior, diff diff --git a/rl_sandbox/agents/dreamer/rssm_slots.py b/rl_sandbox/agents/dreamer/rssm_slots.py new file mode 100644 index 0000000..24ca200 --- /dev/null +++ b/rl_sandbox/agents/dreamer/rssm_slots.py @@ -0,0 +1,177 @@ +import typing as t +from dataclasses import dataclass + +import torch +from jaxtyping import Bool, Float +from torch import nn + +from rl_sandbox.agents.dreamer import Dist, View, GRUCell + + +@dataclass +class State: + determ: Float[torch.Tensor, 'seq batch num_slots determ'] + stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] + stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None + + def flatten(self): + return State(self.determ.flatten(0, 1).unsqueeze(0), + self.stoch_logits.flatten(0, 1).unsqueeze(0), + self.stoch_.flatten(0, 1).unsqueeze(0) if self.stoch_ is not None else None) + + + def detach(self): + return State(self.determ.detach(), + self.stoch_logits.detach(), + self.stoch_.detach() if self.stoch_ is not None else None) + + @property + def combined(self): + return torch.concat([self.determ, self.stoch], dim=-1).flatten(2, 3) + + @property + def combined_slots(self): + return torch.concat([self.determ, self.stoch], dim=-1) + + @property + def stoch(self): + if self.stoch_ is None: + self.stoch_ = Dist( + self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:3] + (-1, )) + return self.stoch_ + + @property + def stoch_dist(self): + return Dist(self.stoch_logits) + + @classmethod + def stack(cls, states: list['State'], dim=0): + if states[0].stoch_ is not None: + stochs = torch.cat([state.stoch for state in states], dim=dim) + else: + stochs = None + return State(torch.cat([state.determ for state in states], dim=dim), + torch.cat([state.stoch_logits for state in states], dim=dim), stochs) + + +class RSSM(nn.Module): + """ + Recurrent State Space Model + + h_t <- deterministic state which is updated inside GRU + s^_t <- stohastic discrete prior state (used for KL divergence: + better predict future and encode smarter) + s_t <- stohastic discrete posterior state (latent representation of current state) + + h_1 ---> h_2 ---> h_3 ---> + \\ x_1 \\ x_2 \\ x_3 + | \\ | ^ | \\ | ^ | \\ | ^ + v MLP CNN | v MLP CNN | v MLP CNN | + \\ | | \\ | | \\ | | + Ensemble \\ | | Ensemble \\ | | Ensemble \\ | | + \\| | \\| | \\| | + | | | | | | | | | + v v | v v | v v | + | | | + s^_1 s_1 ---| s^_2 s_2 ---| s^_3 s_3 ---| + + """ + + def __init__(self, + latent_dim, + hidden_size, + actions_num, + latent_classes, + discrete_rssm, + norm_layer: nn.LayerNorm | nn.Identity, + embed_size=2 * 2 * 384): + super().__init__() + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.ensemble_num = 1 + self.hidden_size = hidden_size + self.discrete_rssm = discrete_rssm + + # Calculate deterministic state from prev stochastic, prev action and prev deterministic + self.pre_determ_recurrent = nn.Sequential( + nn.Linear(latent_dim * latent_classes + actions_num, + hidden_size), # Dreamer 'img_in' + norm_layer(hidden_size), + nn.ELU(inplace=True)) + self.determ_recurrent = 
GRUCell(input_size=hidden_size, + hidden_size=hidden_size, + norm=True) # Dreamer gru '_cell' + + # Calculate stochastic state from prior embed + # shared between all ensemble models + self.ensemble_prior_estimator = nn.ModuleList([ + nn.Sequential( + nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' + View((1, -1, latent_dim, self.latent_classes))) + for _ in range(self.ensemble_num) + ]) + + # For observation we do not have ensemble + img_sz = embed_size + self.stoch_net = nn.Sequential( + # nn.LayerNorm(hidden_size + img_sz, hidden_size), + nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'obs_dist' + View((1, -1, latent_dim, self.latent_classes))) + + def on_train_step(self): + pass + + def estimate_stochastic_latent(self, prev_determ: torch.Tensor): + dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] + # NOTE: Maybe something smarter can be used instead of + # taking only one random between all ensembles + # NOTE: in Dreamer ensemble_num is always 1 + idx = torch.randint(0, self.ensemble_num, ()) + return dists_per_model[0] + + def predict_next(self, prev_state: State, action) -> State: + x = self.pre_determ_recurrent( + torch.concat([ + prev_state.stoch, + action.unsqueeze(2).repeat((1, 1, prev_state.determ.shape[2], 1)) + ], + dim=-1)) + # NOTE: x and determ are actually the same value if sequence of 1 is inserted + x, determ_prior = self.determ_recurrent(x.flatten(1, 2), + prev_state.determ.flatten(1, 2)) + if self.discrete_rssm: + raise NotImplementedError("discrete rssm was not adopted for slot attention") + else: + determ_post, diff = determ_prior, 0 + + # used for KL divergence + predicted_stoch_logits = self.estimate_stochastic_latent(x) + # Size is 1 x B x slots_num x ... 
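+        # Slots were folded into the batch dimension via flatten(1, 2) above, so
+        # the shared GRU and prior MLP are applied to every slot independently;
+        # the reshapes below restore the (seq, batch, num_slots, ...) layout.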
+ return State(determ_post.reshape(prev_state.determ.shape), + predicted_stoch_logits.reshape(prev_state.stoch_logits.shape)), diff + + def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' + return State( + prior.determ, + self.stoch_net(torch.concat([prior.determ, embed], dim=-1)).flatten( + 1, 2).reshape(prior.stoch_logits.shape)) + + def forward(self, h_prev: State, embed, action) -> tuple[State, State]: + """ + 'h' <- internal state of the world + 'z' <- latent embedding of current observation + 'a' <- action taken on prev step + Returns 'h_next' <- the next next of the world + """ + prior, diff = self.predict_next(h_prev, action) + posterior = self.update_current(prior, embed) + + return prior, posterior, diff diff --git a/rl_sandbox/agents/dreamer/rssm_slots_attention.py b/rl_sandbox/agents/dreamer/rssm_slots_attention.py new file mode 100644 index 0000000..6d10c06 --- /dev/null +++ b/rl_sandbox/agents/dreamer/rssm_slots_attention.py @@ -0,0 +1,243 @@ +import typing as t +from dataclasses import dataclass + +import torch +from jaxtyping import Bool, Float +from torch import nn +import torch.nn.functional as F + +from rl_sandbox.agents.dreamer import Dist, View, GRUCell +from rl_sandbox.utils.schedulers import LinearScheduler + + +@dataclass +class State: + determ: Float[torch.Tensor, 'seq batch num_slots determ'] + stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] + stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None + pos_enc: t.Optional[Float[torch.Tensor, '1 1 num_slots stoch_dim+determ']] = None + determ_updated: t.Optional[Float[torch.Tensor, 'seq batch num_slots determ']] = None + + def flatten(self): + return State(self.determ.flatten(0, 1).unsqueeze(0), + self.stoch_logits.flatten(0, 1).unsqueeze(0), + self.stoch_.flatten(0, 1).unsqueeze(0) if self.stoch_ is not None else None, + self.pos_enc if self.pos_enc is not None else None) + + def detach(self): + return State(self.determ.detach(), + self.stoch_logits.detach(), + self.stoch_.detach() if self.stoch_ is not None else None, + self.pos_enc.detach() if self.pos_enc is not None else None) + + @property + def combined(self): + return self.combined_slots.flatten(2, 3) + + @property + def combined_slots(self): + state = torch.concat([self.determ, self.stoch], dim=-1) + if self.pos_enc is not None: + return state + self.pos_enc + else: + return state + + @property + def stoch(self): + if self.stoch_ is None: + self.stoch_ = Dist( + self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:3] + (-1, )) + return self.stoch_ + + @property + def stoch_dist(self): + return Dist(self.stoch_logits) + + @classmethod + def stack(cls, states: list['State'], dim=0): + if states[0].stoch_ is not None: + stochs = torch.cat([state.stoch for state in states], dim=dim) + else: + stochs = None + return State(torch.cat([state.determ for state in states], dim=dim), + torch.cat([state.stoch_logits for state in states], dim=dim), + stochs, + states[0].pos_enc) + + +class RSSM(nn.Module): + """ + Recurrent State Space Model + + h_t <- deterministic state which is updated inside GRU + s^_t <- stohastic discrete prior state (used for KL divergence: + better predict future and encode smarter) + s_t <- stohastic discrete posterior state (latent representation of current state) + + h_1 ---> h_2 ---> h_3 ---> + \\ x_1 \\ x_2 \\ x_3 + | \\ | ^ | \\ | ^ | \\ | ^ + v MLP CNN | v MLP CNN | v MLP CNN | + \\ | | \\ | | \\ | | + Ensemble \\ | | Ensemble \\ | | Ensemble \\ | | + 
\\| | \\| | \\| | + | | | | | | | | | + v v | v v | v v | + | | | + s^_1 s_1 ---| s^_2 s_2 ---| s^_3 s_3 ---| + + """ + + def __init__(self, + latent_dim, + hidden_size, + actions_num, + latent_classes, + discrete_rssm, + norm_layer: nn.LayerNorm | nn.Identity, + full_qk_from: int = 1, + symmetric_qk: bool = False, + attention_block_num: int = 3, + embed_size=2 * 2 * 384): + super().__init__() + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.ensemble_num = 1 + self.hidden_size = hidden_size + self.discrete_rssm = discrete_rssm + + self.symmetric_qk = symmetric_qk + + # Calculate deterministic state from prev stochastic, prev action and prev deterministic + self.pre_determ_recurrent = nn.Sequential( + nn.Linear(latent_dim * latent_classes + actions_num, + hidden_size), # Dreamer 'img_in' + norm_layer(hidden_size), + nn.ELU(inplace=True)) + self.determ_recurrent = GRUCell(input_size=hidden_size, + hidden_size=hidden_size, + norm=True) # Dreamer gru '_cell' + + # Calculate stochastic state from prior embed + # shared between all ensemble models + self.ensemble_prior_estimator = nn.Sequential( + nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' + View((1, -1, latent_dim, self.latent_classes))) + + # For observation we do not have ensemble + img_sz = embed_size + self.stoch_net = nn.Sequential( + # nn.LayerNorm(hidden_size + img_sz, hidden_size), + nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'obs_dist' + View((1, -1, latent_dim, self.latent_classes))) + + self.hidden_attention_proj = nn.Linear(hidden_size, 3*hidden_size, bias=False) + self.pre_norm = nn.LayerNorm(hidden_size) + + self.fc = nn.Linear(hidden_size, hidden_size) + self.fc_norm = nn.LayerNorm(hidden_size) + + self.attention_scheduler = LinearScheduler(0.0, 1.0, full_qk_from) + self.attention_block_num = attention_block_num + self.att_scale = hidden_size**(-0.5) + self.eps = 1e-8 + + # self.hidden_attention_proj_obs = nn.Linear(embed_size, embed_size) + # self.hidden_attention_proj_obs_state = nn.Linear(hidden_size, embed_size) + # self.pre_norm_obs = nn.LayerNorm(embed_size) + + # self.fc_obs = nn.Linear(embed_size, embed_size) + # self.fc_norm_obs = nn.LayerNorm(embed_size) + + def on_train_step(self): + self.attention_scheduler.step() + + def estimate_stochastic_latent(self, prev_determ: torch.Tensor): + return self.ensemble_prior_estimator(prev_determ) + + def predict_next(self, prev_state: State, action) -> State: + x = self.pre_determ_recurrent( + torch.concat([ + prev_state.stoch, + action.unsqueeze(2).repeat((1, 1, prev_state.determ.shape[2], 1)) + ], + dim=-1)) + + # NOTE: x and determ are actually the same value if sequence of 1 is inserted + x, determ_prior = self.determ_recurrent(x.flatten(1, 2), + prev_state.determ.flatten(1, 2)) + if self.discrete_rssm: + raise NotImplementedError("discrete rssm was not adopted for slot attention") + else: + determ_post, diff = determ_prior, 0 + + determ_post = determ_post.reshape(prev_state.determ.shape) + + # TODO: Introduce self-attention block here ! 
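+        # The loop below applies `attention_block_num` rounds of scaled
+        # dot-product self-attention across slots on the deterministic state
+        # (k is tied to q when `symmetric_qk` is set).  The attention matrix is
+        # linearly annealed from the identity towards the learned weights by
+        # `attention_scheduler`, so early in training slots do not yet exchange
+        # information.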
+ # Experiment, when only stochastic part is affected and deterministic is not touched + # We keep flow of gradients through determ block, but updating it with stochastic part + for _ in range(self.attention_block_num): + # FIXME: Should the the prev stochastic component also be used ? + q, k, v = self.hidden_attention_proj(self.pre_norm(determ_post)).chunk(3, dim=-1) + if self.symmetric_qk: + k = q + qk = torch.einsum('lbih,lbjh->lbij', q, k).float() + + attn = torch.softmax(self.att_scale * qk, dim=-1) + self.eps + attn = attn / attn.sum(dim=-1, keepdim=True) + + coeff = self.attention_scheduler.val + attn = coeff * attn + (1 - coeff) * torch.eye(q.shape[-2],device=q.device) + + updates = torch.einsum('lbjd,lbij->lbid', v, attn) + determ_post = determ_post + self.fc(self.fc_norm(updates)) + + self.last_attention = attn.mean(dim=1).squeeze() + + # used for KL divergence + predicted_stoch_logits = self.estimate_stochastic_latent(determ_post.reshape(determ_prior.shape)).reshape(prev_state.stoch_logits.shape) + # Size is 1 x B x slots_num x ... + return State(determ_prior.reshape(prev_state.determ.shape), + predicted_stoch_logits, pos_enc=prev_state.pos_enc, determ_updated=determ_post), diff + + def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' + # k = self.hidden_attention_proj_obs_state(self.pre_norm(prior.determ_updated)) + # q = self.hidden_attention_proj_obs(self.pre_norm_obs(embed)) + # qk = torch.einsum('lbih,lbjh->lbij', q, k) + + # # TODO: Use Gumbel Softmax + # attn = torch.softmax(self.att_scale * qk + self.eps, dim=-1) + # attn = attn / attn.sum(dim=-1, keepdim=True) + + # # TODO: Maybe make this a learnable parameter ? + # coeff = min((self.attention_scheduler.val * 5), 1.0) + # attn = coeff * attn + (1 - coeff) * torch.eye(q.shape[-2],device=q.device) + + # embed = torch.einsum('lbij,lbjh->lbih', attn, embed) + # self.embed_attn = attn.squeeze() + + return State( + prior.determ, + self.stoch_net(torch.concat([prior.determ_updated, embed], dim=-1)).flatten( + 1, 2).reshape(prior.stoch_logits.shape), pos_enc=prior.pos_enc) + + def forward(self, h_prev: State, embed, action) -> tuple[State, State]: + """ + 'h' <- internal state of the world + 'z' <- latent embedding of current observation + 'a' <- action taken on prev step + Returns 'h_next' <- the next next of the world + """ + prior, diff = self.predict_next(h_prev, action) + posterior = self.update_current(prior, embed) + + return prior, posterior, diff + diff --git a/rl_sandbox/agents/dreamer/rssm_slots_combined.py b/rl_sandbox/agents/dreamer/rssm_slots_combined.py new file mode 100644 index 0000000..19497b6 --- /dev/null +++ b/rl_sandbox/agents/dreamer/rssm_slots_combined.py @@ -0,0 +1,217 @@ +import typing as t +from dataclasses import dataclass + +import torch +from jaxtyping import Bool, Float +from torch import nn + +from rl_sandbox.agents.dreamer import Dist, View, GRUCell + + +@dataclass +class State: + determ: Float[torch.Tensor, 'seq batch num_slots determ'] + stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] + stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None + pos_enc: t.Optional[Float[torch.Tensor, '1 1 num_slots stoch_dim+determ']] = None + + def flatten(self): + return State(self.determ.flatten(0, 1).unsqueeze(0), + self.stoch_logits.flatten(0, 1).unsqueeze(0), + self.stoch_.flatten(0, 1).unsqueeze(0) if self.stoch_ is not None else None, + self.pos_enc if self.pos_enc is not None else None) + + def detach(self): + return 
State(self.determ.detach(), + self.stoch_logits.detach(), + self.stoch_.detach() if self.stoch_ is not None else None, + self.pos_enc.detach() if self.pos_enc is not None else None) + + @property + def combined(self): + return self.combined_slots.flatten(2, 3) + + @property + def combined_slots(self): + state = torch.concat([self.determ, self.stoch], dim=-1) + if self.pos_enc is not None: + return state + self.pos_enc + else: + return state + + @property + def stoch(self): + if self.stoch_ is None: + self.stoch_logits = self.stoch_logits.to(dtype=torch.float32) + self.stoch_ = Dist( + self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:3] + (-1, )) + return self.stoch_ + + @property + def stoch_dist(self): + return Dist(self.stoch_logits) + + @classmethod + def stack(cls, states: list['State'], dim=0): + if states[0].stoch_ is not None: + stochs = torch.cat([state.stoch for state in states], dim=dim) + else: + stochs = None + return State(torch.cat([state.determ for state in states], dim=dim), + torch.cat([state.stoch_logits for state in states], dim=dim), + stochs, + states[0].pos_enc) + + +class GRUCell(nn.Module): + + def __init__(self, input_size, hidden_size, norm=False, update_bias=-1, **kwargs): + super().__init__() + self._size = hidden_size + self._act = torch.tanh + self._norm = norm + self._update_bias = update_bias + self._layer = nn.Linear(input_size + hidden_size, + 3 * hidden_size, + bias=norm is not None, + **kwargs) + if norm: + self._norm = nn.LayerNorm(3 * hidden_size) + + @property + def state_size(self): + return self._size + + def forward(self, x, h): + state = h + parts = self._layer(torch.concat([x, state], -1)) + if self._norm: + dtype = parts.dtype + parts = self._norm(parts.float()) + parts = parts.to(dtype=dtype) + reset, cand, update = parts.chunk(3, dim=-1) + reset = torch.sigmoid(reset) + cand = self._act(reset * cand) + update = torch.sigmoid(update + self._update_bias) + output = update * cand + (1 - update) * state + return output, output + + +class RSSM(nn.Module): + """ + Recurrent State Space Model + + h_t <- deterministic state which is updated inside GRU + s^_t <- stohastic discrete prior state (used for KL divergence: + better predict future and encode smarter) + s_t <- stohastic discrete posterior state (latent representation of current state) + + h_1 ---> h_2 ---> h_3 ---> + \\ x_1 \\ x_2 \\ x_3 + | \\ | ^ | \\ | ^ | \\ | ^ + v MLP CNN | v MLP CNN | v MLP CNN | + \\ | | \\ | | \\ | | + Ensemble \\ | | Ensemble \\ | | Ensemble \\ | | + \\| | \\| | \\| | + | | | | | | | | | + v v | v v | v v | + | | | + s^_1 s_1 ---| s^_2 s_2 ---| s^_3 s_3 ---| + + """ + + def __init__(self, + latent_dim, + hidden_size, + actions_num, + latent_classes, + discrete_rssm, + norm_layer: nn.LayerNorm | nn.Identity, + slots_num: int, + embed_size=2 * 2 * 384): + super().__init__() + self.slots_num = slots_num + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.ensemble_num = 1 + self.hidden_size = hidden_size + self.discrete_rssm = discrete_rssm + + # Calculate deterministic state from prev stochastic, prev action and prev deterministic + self.pre_determ_recurrent = nn.Sequential( + nn.Linear(latent_dim * latent_classes + actions_num, + hidden_size), # Dreamer 'img_in' + norm_layer(hidden_size), + nn.ELU(inplace=True)) + self.determ_recurrent = GRUCell(input_size=hidden_size*slots_num, + hidden_size=hidden_size*slots_num, + norm=True) # Dreamer gru '_cell' + + # Calculate stochastic state from prior embed + # shared between all ensemble 
models + self.ensemble_prior_estimator = nn.Sequential( + nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' + View((1, -1, latent_dim, self.latent_classes))) + + img_sz = embed_size + self.stoch_net = nn.Sequential( + # nn.LayerNorm(hidden_size + img_sz, hidden_size), + nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'obs_dist' + View((1, -1, latent_dim, self.latent_classes))) + + def on_train_step(self): + pass + + def estimate_stochastic_latent(self, prev_determ: torch.Tensor): + return self.ensemble_prior_estimator(prev_determ) + + def predict_next(self, prev_state: State, action) -> State: + x = self.pre_determ_recurrent( + torch.concat([ + prev_state.stoch, + action.unsqueeze(2).repeat((1, 1, prev_state.determ.shape[2], 1)) + ], + dim=-1)) + # NOTE: x and determ are actually the same value if sequence of 1 is inserted + x, determ_prior = self.determ_recurrent(x.flatten(2, 3), + prev_state.determ.flatten(2, 3)) + if self.discrete_rssm: + raise NotImplementedError("discrete rssm was not adopted for slot attention") + else: + determ_post, diff = determ_prior, 0 + + # used for KL divergence + # TODO: Test both options (with slot in batch size and in feature dim) + predicted_stoch_logits = self.estimate_stochastic_latent(x.reshape(prev_state.determ.shape)) + # Size is 1 x B x slots_num x ... + return State(determ_post.reshape(prev_state.determ.shape), + predicted_stoch_logits.reshape(prev_state.stoch_logits.shape), + pos_enc=prev_state.pos_enc), diff + + def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' + return State( + prior.determ, + self.stoch_net(torch.concat([prior.determ, embed], dim=-1)).flatten( + 1, 2).reshape(prior.stoch_logits.shape), pos_enc=prior.pos_enc) + + def forward(self, h_prev: State, embed, action) -> tuple[State, State]: + """ + 'h' <- internal state of the world + 'z' <- latent embedding of current observation + 'a' <- action taken on prev step + Returns 'h_next' <- the next next of the world + """ + prior, diff = self.predict_next(h_prev, action) + posterior = self.update_current(prior, embed) + + return prior, posterior, diff + + diff --git a/rl_sandbox/agents/dreamer/vision.py b/rl_sandbox/agents/dreamer/vision.py new file mode 100644 index 0000000..0b706c5 --- /dev/null +++ b/rl_sandbox/agents/dreamer/vision.py @@ -0,0 +1,187 @@ +import torch.distributions as td +from torch import nn +import torch +from rl_sandbox.vision.slot_attention import PositionalEmbedding + + +class Encoder(nn.Module): + + def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, + channel_step=96, + kernel_sizes=[4, 4, 4, 4], + post_conv_num: int = 0, + flatten_output=True, + in_channels=3, + ): + super().__init__() + layers = [] + + for i, k in enumerate(kernel_sizes): + out_channels = 2**i * channel_step + layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2)) + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) + in_channels = out_channels + + for k in range(post_conv_num): + layers.append( + nn.Conv2d(out_channels, out_channels, kernel_size=5, padding='same')) + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) + + if flatten_output: + layers.append(nn.Flatten()) + self.net = 
nn.Sequential(*layers) + + def forward(self, X): + return self.net(X) + + +class SpatialBroadcastDecoder(nn.Module): + + def __init__(self, + input_size, + norm_layer: nn.GroupNorm | nn.Identity, + kernel_sizes = [3, 3, 3], + out_image=(64, 64), + channel_step=64, + output_channels=3, + return_dist=True): + + super().__init__() + layers = [] + self.channel_step = channel_step + self.in_channels = 2*self.channel_step + self.out_shape = out_image + self.positional_augmenter = PositionalEmbedding(self.in_channels, out_image) + + in_channels = self.in_channels + self.convin = nn.Linear(input_size, in_channels) + self.return_dist = return_dist + + for i, k in enumerate(kernel_sizes): + out_channels = channel_step + if i == len(kernel_sizes) - 1: + out_channels = output_channels + layers.append(nn.Conv2d(in_channels, + out_channels, + kernel_size=k, + padding='same')) + else: + layers.append(nn.Conv2d(in_channels, + out_channels, + kernel_size=k, + padding='same')) + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) + in_channels = out_channels + + self.net = nn.Sequential(*layers) + + def forward(self, X): + x = self.convin(X) + x = x.view(-1, self.in_channels, 1, 1) + x = torch.tile(x, self.out_shape) + x = self.positional_augmenter(x) + if self.return_dist: + return td.Independent(td.Normal(self.net(x), 1.0), 3) + else: + return self.net(x) + +class Decoder(nn.Module): + + def __init__(self, + input_size, + norm_layer: nn.GroupNorm | nn.Identity, + kernel_sizes=[5, 5, 6, 6], + channel_step = 48, + output_channels=3, + conv_kernel_sizes=[], + return_dist=True): + super().__init__() + layers = [] + self.channel_step = channel_step + self.in_channels = 2 **(len(kernel_sizes)+1) * self.channel_step + in_channels = self.in_channels + self.convin = nn.Linear(input_size, in_channels) + self.return_dist = return_dist + + for i, k in enumerate(kernel_sizes): + out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step + if i == len(kernel_sizes) - 1: + out_channels = output_channels + layers.append(nn.ConvTranspose2d(in_channels, + output_channels, + kernel_size=k, + stride=2, + output_padding=0)) + else: + layers.append( + nn.ConvTranspose2d(in_channels, + out_channels, + kernel_size=k, + stride=2, + output_padding=0)) + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) + in_channels = out_channels + + for k in conv_kernel_sizes: + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) + layers.append( + nn.Conv2d(output_channels, + output_channels, + kernel_size=k, + padding='same')) + self.net = nn.Sequential(*layers) + + def forward(self, X): + x = self.convin(X) + x = x.view(-1, self.in_channels, 1, 1) + if self.return_dist: + return td.Independent(td.Normal(self.net(x), 1.0), 3) + else: + return self.net(x) + + +class ViTDecoder(nn.Module): + + def __init__(self, + input_size, + norm_layer: nn.GroupNorm | nn.Identity, + kernel_sizes=[5, 5, 5, 3, 3]): + super().__init__() + layers = [] + self.channel_step = 12 + # 2**(len(kernel_sizes)-1)*channel_step + self.convin = nn.Linear(input_size, 32 * self.channel_step) + + in_channels = 32 * self.channel_step #2**(len(kernel_sizes) - 1) * self.channel_step + for i, k in enumerate(kernel_sizes): + out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step + if i == len(kernel_sizes) - 1: + out_channels = 3 + layers.append( + nn.ConvTranspose2d(in_channels, + 384, + kernel_size=k, + stride=1, + padding=1)) + else: + layers.append(norm_layer(1, 
in_channels)) + layers.append( + nn.ConvTranspose2d(in_channels, + out_channels, + kernel_size=k, + stride=2, + padding=2, + output_padding=1)) + layers.append(nn.ELU(inplace=True)) + in_channels = out_channels + self.net = nn.Sequential(*layers) + + def forward(self, X): + x = self.convin(X) + x = x.view(-1, 32 * self.channel_step, 1, 1) + return td.Independent(td.Normal(self.net(x), 1.0), 3) diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py new file mode 100644 index 0000000..a8088ab --- /dev/null +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -0,0 +1,245 @@ +import typing as t + +import torch +import torch.distributions as td +import torchvision as tv +from torch import nn +from torch.nn import functional as F + +from rl_sandbox.agents.dreamer import Dist, Normalizer, View +from rl_sandbox.agents.dreamer.rssm import RSSM, State +from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, SpatialBroadcastDecoder +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.utils.fc_nn import fc_nn_generator +from rl_sandbox.vision.dino import ViTFeat +from rl_sandbox.vision.slot_attention import PositionalEmbedding, SlotAttention + +class WorldModel(nn.Module): + + def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, + actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, + predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, + vit_l2_ratio: float, vit_img_size: int): + super().__init__() + self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.discount_scale = discount_loss_scale + self.kl_beta = kl_loss_scale + + self.rssm_dim = rssm_dim + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.state_size = (rssm_dim + latent_dim * latent_classes) + + self.cluster_size = batch_cluster_size + self.actions_num = actions_num + # kl loss balancing (prior/posterior) + self.alpha = kl_loss_balancing + self.predict_discount = predict_discount + self.encode_vit = encode_vit + self.decode_vit = decode_vit + self.vit_l2_ratio = vit_l2_ratio + self.vit_img_size = vit_img_size + + self.recurrent_model = RSSM(latent_dim, + rssm_dim, + actions_num, + latent_classes, + discrete_rssm, + norm_layer=nn.LayerNorm if layer_norm else nn.Identity) + if encode_vit or decode_vit: + if self.vit_img_size == 224: + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=16) + self.decoder_kernels = [3, 3, 2] + self.vit_size = 14 + elif self.vit_img_size == 64: + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=8) + self.decoder_kernels = [3, 4] + self.vit_size = 8 + else: + raise RuntimeError("Unknown vit img size") + # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) + self.vit_feat_dim = self.dino_vit.feat_dim + self.dino_vit.requires_grad_(False) + + if encode_vit: + self.post_vit = nn.Sequential( + View((-1, self.vit_feat_dim, self.vit_size, self.vit_size)), + Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[2], + channel_step=384, + flatten_output=False, + in_channels=self.vit_feat_dim + ) + ) + self.encoder = nn.Sequential( + self.dino_vit, + self.post_vit + ) + else: + self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[4, 
4, 4, 4], + channel_step=48) + + if decode_vit: + self.dino_predictor = SpatialBroadcastDecoder(self.state_size, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + out_image=(14, 14), + kernel_sizes = [5, 5, 5, 5], + channel_step=self.vit_feat_dim, + output_channels=self.vit_feat_dim, + return_dist=True) + + self.image_predictor = Decoder(self.state_size, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity) + + self.reward_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.discount_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('binary')) + self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: + if not self.decode_vit: + return {} + if not self.encode_vit: + ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(self.vit_img_size, antialias=True)]) + obs = ToTensor(obs + 0.5) + with torch.no_grad(): + d_features = self.dino_vit(obs).cpu() + return {'d_features': d_features} + + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): + device = next(self.parameters()).device + return State(torch.zeros(seq_size, batch_size, self.rssm_dim, device=device), + torch.zeros(seq_size, batch_size, self.latent_classes, self.latent_dim, device=device), + torch.zeros(seq_size, batch_size, self.latent_classes * self.latent_dim, device=device)) + + def predict_next(self, prev_state: State, action): + prior, _ = self.recurrent_model.predict_next(prev_state, action) + + # FIXME: rewrite to utilize batch processing + reward = self.reward_predictor(prior.combined).mode + if self.predict_discount: + discount_factors = self.discount_predictor(prior.combined).mode + else: + discount_factors = torch.ones_like(reward) + return prior, reward, discount_factors + + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State]) -> State: + if state is None: + state = self.get_initial_state() + embed = self.encoder(obs.unsqueeze(0)) + _, posterior, _ = self.recurrent_model.forward(state, embed.unsqueeze(0), + action) + return posterior + + def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): + self.recurrent_model.on_train_step() + b, _, h, w = obs.shape # s <- BxHxWx3 + + if self.encode_vit: + embed = self.post_vit(additional['d_features']) + else: + embed = self.encoder(obs) + embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) + + a_c = a.reshape(-1, self.cluster_size, self.actions_num) + r_c = r.reshape(-1, self.cluster_size, 1) + d_c = discount.reshape(-1, self.cluster_size, 1) + first_c = first.reshape(-1, self.cluster_size, 1) + + losses = {} + metrics = {} + + def KL(dist1, dist2, free_nat = True): + KL_ = torch.distributions.kl_divergence + one = self.kl_free_nats * torch.ones(1, device=next(self.parameters()).device) + # TODO: kl_free_avg is used always + if free_nat: + kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)).mean(), one) + kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())).mean(), one) + else: + kl_lhs = KL_(Dist(dist2.detach()), Dist(dist1)).mean() + kl_rhs = KL_(Dist(dist2), 
Dist(dist1.detach())).mean() + return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + + priors = [] + posteriors = [] + + if self.decode_vit: + d_features = additional['d_features'] + + prev_state = self.get_initial_state(b // self.cluster_size) + for t in range(self.cluster_size): + # s_t <- 1xB^xHxWx3 + embed_t, a_t, first_t = embed_c[:, t].unsqueeze(0), a_c[:, t].unsqueeze(0), first_c[:, t].unsqueeze(0) + a_t = a_t * (1 - first_t) + + prior, posterior, diff = self.recurrent_model.forward(prev_state, embed_t, a_t) + prev_state = posterior + + priors.append(prior) + posteriors.append(posterior) + + # losses['loss_determ_recons'] += diff + + posterior = State.stack(posteriors) + prior = State.stack(priors) + + r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) + f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) + + losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) + + if not self.decode_vit: + x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() + else: + if self.vit_l2_ratio != 1.0: + x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + img_rec = -x_r.log_prob(obs).float().mean() + else: + img_rec = torch.tensor(0, device=obs.device) + x_r_detached = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1).detach()) + losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() + + d_pred = self.dino_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) + d_rec = -d_pred.log_prob(d_obs).float().mean() + d_rec = d_rec / torch.prod(torch.tensor(d_obs.shape[-3:])) * torch.prod(torch.tensor(obs.shape[-3:])) + + losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) + metrics['loss_l2_rec'] = img_rec + metrics['loss_dino_rec'] = d_rec + + prior_logits = prior.stoch_logits + posterior_logits = posterior.stoch_logits + losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() + losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() + losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) + + metrics['reward_mean'] = r.mean() + metrics['reward_std'] = r.std() + metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() + metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() + metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() + + losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + + self.kl_beta * losses['loss_kl_reg'] + self.discount_scale*losses['loss_discount_pred']) + + return losses, posterior, metrics diff --git a/rl_sandbox/agents/dreamer/world_model_post_slot.py b/rl_sandbox/agents/dreamer/world_model_post_slot.py new file mode 100644 index 0000000..aad0e55 --- /dev/null +++ b/rl_sandbox/agents/dreamer/world_model_post_slot.py @@ -0,0 +1,359 @@ +import typing as t + +import torch +import math +import torch.distributions as td +import torchvision as tv +from torch import nn +from torch.nn import functional as F + +from rl_sandbox.agents.dreamer import Dist, Normalizer, View +from rl_sandbox.agents.dreamer.rssm import RSSM, State +from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, SpatialBroadcastDecoder +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.utils.fc_nn import fc_nn_generator +from rl_sandbox.vision.dino import ViTFeat +from 
rl_sandbox.vision.slot_attention import PositionalEmbedding, SlotAttention + +class WorldModel(nn.Module): + + def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, + actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, + predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, + vit_l2_ratio: float, vit_img_size: int, slots_num: int, slots_iter_num: int, + mask_combination: str, use_reshuffle: bool, per_slot_rec_loss: bool, spatial_decoder: bool): + super().__init__() + self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.discount_scale = discount_loss_scale + self.kl_beta = kl_loss_scale + self.per_slot_rec_loss = per_slot_rec_loss + + self.rssm_dim = rssm_dim + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.state_size = (rssm_dim + latent_dim * latent_classes) + + self.cluster_size = batch_cluster_size + self.actions_num = actions_num + # kl loss balancing (prior/posterior) + self.alpha = kl_loss_balancing + self.predict_discount = predict_discount + self.encode_vit = encode_vit + self.decode_vit = decode_vit + self.vit_l2_ratio = vit_l2_ratio + self.vit_img_size = vit_img_size + + self.recurrent_model = RSSM(latent_dim, + rssm_dim, + actions_num, + latent_classes, + discrete_rssm, + norm_layer=nn.LayerNorm if layer_norm else nn.Identity) + if encode_vit or decode_vit: + if self.vit_img_size == 224: + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=16) + self.decoder_kernels = [3, 3, 2] + self.vit_size = 14 + elif self.vit_img_size == 64: + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=8) + self.decoder_kernels = [3, 4] + self.vit_size = 8 + else: + raise RuntimeError("Unknown vit img size") + # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) + self.vit_feat_dim = self.dino_vit.feat_dim + self.dino_vit.requires_grad_(False) + + if encode_vit: + self.post_vit = nn.Sequential( + View((-1, self.vit_feat_dim, self.vit_size, self.vit_size)), + Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[2], + channel_step=384, + flatten_output=False, + in_channels=self.vit_feat_dim + ) + ) + self.encoder = nn.Sequential( + self.dino_vit, + self.post_vit + ) + else: + self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[4, 4, 4, 4], + channel_step=48) + + self.n_dim = 256 + + if decode_vit: + if spatial_decoder: + self.dino_predictor = SpatialBroadcastDecoder(self.n_dim, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + out_image=(14, 14), + kernel_sizes = [5, 5, 5], + channel_step=self.vit_feat_dim, + output_channels=self.vit_feat_dim+1, + return_dist=False) + else: + self.dino_predictor = Decoder(self.n_dim, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + conv_kernel_sizes=[3], + channel_step=self.vit_feat_dim, + kernel_sizes=self.decoder_kernels, + output_channels=self.vit_feat_dim+1, + return_dist=False) + + self.slots_num = slots_num + self.mask_combination = mask_combination + self.state_feature_num = (self.state_size//self.n_dim) + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (1, self.state_feature_num), channel_last=True) + # TODO: slots will assume permutation-invariance + self.slot_attention = 
SlotAttention(slots_num, self.n_dim, slots_iter_num, True) + self.use_reshuffle = use_reshuffle + if self.use_reshuffle: + self.state_reshuffle = nn.Sequential(nn.Linear(self.state_size, self.state_feature_num*self.n_dim), + nn.ReLU(inplace=True), + nn.Linear(self.state_feature_num*self.n_dim, self.state_feature_num*self.n_dim)) + + # self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), + # nn.ReLU(inplace=True), + # nn.Linear(self.n_dim, self.n_dim)) + + if not decode_vit: + self.image_predictor = Decoder(self.n_dim, + output_channels=4, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + return_dist=False) + + self.reward_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.discount_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('binary')) + self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + + def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: + match self.mask_combination: + case 'soft': + img_mask = F.softmax(masks, dim=-4) + case 'hard': + probs = F.softmax(masks - masks.logsumexp(dim=1,keepdim=True), dim=1) + img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + (probs - probs.detach()) + case 'qmix': + raise NotImplementedError + case _: + raise NotImplementedError + return img_mask + + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: + if not self.decode_vit: + return {} + if not self.encode_vit: + ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(self.vit_img_size, antialias=True)]) + obs = ToTensor(obs + 0.5) + with torch.no_grad(): + d_features = self.dino_vit(obs).cpu() + return {'d_features': d_features} + + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): + device = next(self.parameters()).device + return State(torch.zeros(seq_size, batch_size, self.rssm_dim, device=device), + torch.zeros(seq_size, batch_size, self.latent_classes, self.latent_dim, device=device), + torch.zeros(seq_size, batch_size, self.latent_classes * self.latent_dim, device=device)) + + def predict_next(self, prev_state: State, action): + prior, _ = self.recurrent_model.predict_next(prev_state, action) + + # FIXME: rewrite to utilize batch processing + reward = self.reward_predictor(prior.combined).mode + if self.predict_discount: + discount_factors = self.discount_predictor(prior.combined).mode + else: + discount_factors = torch.ones_like(reward) + return prior, reward, discount_factors + + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State]) -> State: + if isinstance(state, tuple): + state = state[0] + if state is None: + state = self.get_initial_state() + embed = self.encoder(obs.unsqueeze(0)) + _, posterior, _ = self.recurrent_model.forward(state, embed.unsqueeze(0), + action) + return posterior, None + + def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): + self.recurrent_model.on_train_step() + b, _, h, w = obs.shape # s <- BxHxWx3 + + if self.encode_vit: + embed = self.post_vit(additional['d_features']) + else: + embed = self.encoder(obs) + embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, 
-1) + + a_c = a.reshape(-1, self.cluster_size, self.actions_num) + r_c = r.reshape(-1, self.cluster_size, 1) + d_c = discount.reshape(-1, self.cluster_size, 1) + first_c = first.reshape(-1, self.cluster_size, 1) + + losses = {} + metrics = {} + + def KL(dist1, dist2, free_nat = True): + KL_ = torch.distributions.kl_divergence + one = self.kl_free_nats * torch.ones(1, device=next(self.parameters()).device) + # TODO: kl_free_avg is used always + if free_nat: + kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)).mean(), one) + kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())).mean(), one) + else: + kl_lhs = KL_(Dist(dist2.detach()), Dist(dist1)).mean() + kl_rhs = KL_(Dist(dist2), Dist(dist1.detach())).mean() + return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + + priors = [] + posteriors = [] + + if self.decode_vit: + d_features = additional['d_features'] + + prev_state = self.get_initial_state(b // self.cluster_size) + for t in range(self.cluster_size): + # s_t <- 1xB^xHxWx3 + embed_t, a_t, first_t = embed_c[:, t].unsqueeze(0), a_c[:, t].unsqueeze(0), first_c[:, t].unsqueeze(0) + a_t = a_t * (1 - first_t) + + prior, posterior, diff = self.recurrent_model.forward(prev_state, embed_t, a_t) + prev_state = posterior + + priors.append(prior) + posteriors.append(posterior) + + # losses['loss_determ_recons'] += diff + + posterior = State.stack(posteriors) + prior = State.stack(priors) + + if self.use_reshuffle: + state = self.state_reshuffle(posterior.combined.transpose(0, 1)) + else: + state = posterior.combined.transpose(0, 1) + assert state.shape[-1] % self.n_dim == 0 and self.rssm_dim % self.n_dim == 0 + state = state.reshape(*state.shape[:-1], self.state_feature_num, self.n_dim) + state_pos_embedded = self.positional_augmenter_inp(state.unsqueeze(-3)).squeeze(-3) + + state_slots = self.slot_attention(state_pos_embedded.flatten(0, 1), None) + + r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) + f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) + + losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) + + if not self.decode_vit: + # x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + # losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() + decoded_imgs, masks = self.image_predictor(state_slots.flatten(0, 1)).reshape(b, -1, 4, h, w).split([3, 1], dim=-3) + img_mask = self.slot_mask(masks) + + if self.per_slot_rec_loss: + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) / 3 + img_rec = l2_loss * normalizing_factor + torch.prod(torch.tensor(obs.shape)[-3:]) * math.log((2*math.pi)**(1/2)) + img_rec = img_rec.mean() + + decoded_imgs = decoded_imgs * img_mask + else: + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=-4), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean(dim=0) + losses['loss_reconstruction'] = img_rec + else: + if self.vit_l2_ratio == 1.0: + pass + # decoded_imgs_detached, masks = self.image_predictor(state_slots.flatten(0, 1).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=-3) + # img_mask = self.slot_mask(masks) + + img_rec = torch.tensor(0, device=obs.device) + + # if self.per_slot_rec_loss: + # l2_loss = (img_mask * ((decoded_imgs_detached - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + # normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / 
img_mask.sum(dim=[2, 3, 4]).clamp(min=1) / 3 + # img_rec_detached = l2_loss * normalizing_factor + torch.prod(torch.tensor(obs.shape)[-3:]) * math.log((2*math.pi)**(1/2)) + # img_rec_detached = img_rec_detached.mean() + # decoded_imgs_detached = decoded_imgs_detached * img_mask + # else: + # decoded_imgs_detached = decoded_imgs_detached * img_mask + # x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=-4), 1.0), 3) + + # img_rec_detached = -x_r_detached.log_prob(obs).float().mean() + + # losses['loss_reconstruction_img'] = img_rec_detached + else: + decoded_imgs, masks = self.image_predictor(state_slots.flatten(0, 1)).reshape(b, -1, 4, h, w).split([3, 1], dim=-3) + img_mask = self.slot_mask(masks) + + if self.per_slot_rec_loss: + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) / 3 + img_rec = l2_loss * normalizing_factor + torch.prod(torch.tensor(obs.shape)[-3:]) * math.log((2*math.pi)**(1/2)) + img_rec = img_rec.mean() + + decoded_imgs = decoded_imgs * img_mask + else: + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=-4), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean() + + decoded_feats, masks = self.dino_predictor(state_slots.flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, self.vit_size, self.vit_size).split([self.vit_feat_dim, 1], dim=-3) + feat_mask = self.slot_mask(masks) + + d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) + + decoded_feats = decoded_feats * feat_mask + + if self.per_slot_rec_loss: + l2_loss = (feat_mask * ((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(d_obs.shape)[-3:]) / feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) / 3 + d_rec = l2_loss * normalizing_factor + torch.prod(torch.tensor(d_obs.shape)[-3:]) * math.log((2*math.pi)**(1/2)) + d_rec = d_rec.mean() + else: + d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=-4), 1.0), 3) + d_rec = -d_pred.log_prob(d_obs).float().mean() + + d_rec = d_rec / torch.prod(torch.tensor(d_obs.shape[-3:])) * torch.prod(torch.tensor(obs.shape[-3:])) + + losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) + metrics['loss_l2_rec'] = img_rec + metrics['loss_dino_rec'] = d_rec + + prior_logits = prior.stoch_logits + posterior_logits = posterior.stoch_logits + losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() + losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() + losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) + + metrics['reward_mean'] = r.mean() + metrics['reward_std'] = r.std() + metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() + metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() + metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() + + losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + + self.kl_beta * losses['loss_kl_reg'] + self.discount_scale*losses['loss_discount_pred']) + + return losses, posterior, metrics diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py new file mode 100644 index 0000000..f51f6b3 --- /dev/null +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -0,0 +1,346 @@ +import typing as t + +import torch +import torch.distributions as td +import torchvision as tv +from torch 
import nn +from torch.nn import functional as F + +from rl_sandbox.agents.dreamer import Dist, Normalizer, View +from rl_sandbox.agents.dreamer.rssm_slots import RSSM, State +from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, ViTDecoder +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.utils.fc_nn import fc_nn_generator +from rl_sandbox.vision.dino import ViTFeat +from rl_sandbox.vision.slot_attention import PositionalEmbedding, SlotAttention + + +class WorldModel(nn.Module): + + def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, + actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, + discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, + decode_vit: bool, vit_l2_ratio: float, vit_img_size: int, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, + mask_combination: str = 'soft'): + super().__init__() + self.use_prev_slots = use_prev_slots + self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.discount_scale = discount_loss_scale + self.kl_beta = kl_loss_scale + + self.rssm_dim = rssm_dim + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.slots_num = slots_num + self.mask_combination = mask_combination + self.state_size = slots_num * (rssm_dim + latent_dim * latent_classes) + + self.cluster_size = batch_cluster_size + self.actions_num = actions_num + # kl loss balancing (prior/posterior) + self.alpha = kl_loss_balancing + self.predict_discount = predict_discount + self.encode_vit = encode_vit + self.decode_vit = decode_vit + self.vit_l2_ratio = vit_l2_ratio + self.vit_img_size = vit_img_size + + self.n_dim = 192 + + self.recurrent_model = RSSM( + latent_dim, + rssm_dim, + actions_num, + latent_classes, + discrete_rssm, + norm_layer=nn.LayerNorm if layer_norm else nn.Identity, + embed_size=self.n_dim) + if encode_vit or decode_vit: + if self.vit_img_size == 224: + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=16) + self.decoder_kernels = [3, 3, 2] + self.vit_size = 14 + elif self.vit_img_size == 64: + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=8) + self.decoder_kernels = [3, 4] + self.vit_size = 8 + else: + raise RuntimeError("Unknown vit img size") + # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) + self.vit_feat_dim = self.dino_vit.feat_dim + self.dino_vit.requires_grad_(False) + + if encode_vit: + self.post_vit = nn.Sequential( + View((-1, self.vit_feat_dim, self.vit_size, self.vit_size)), + Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[2], + channel_step=384, + flatten_output=False, + in_channels=self.vit_feat_dim + ) + ) + self.encoder = nn.Sequential( + self.dino_vit, + self.post_vit + ) + else: + self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[4, 4], + channel_step=48 * (self.n_dim // 192) * 2, + post_conv_num=2, + flatten_output=False) + + self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) + if self.encode_vit: + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (14, 14)) + else: + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) + + self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), + 
nn.ReLU(inplace=True), + nn.Linear(self.n_dim, self.n_dim)) + + if decode_vit: + self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + conv_kernel_sizes=[3], + channel_step=2*self.vit_feat_dim, + kernel_sizes=self.decoder_kernels, + output_channels=self.vit_feat_dim+1, + return_dist=False) + self.image_predictor = Decoder( + rssm_dim + latent_dim * latent_classes, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + output_channels=3+1, + return_dist=False) + + self.reward_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.discount_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('binary')) + self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + + def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: + match self.mask_combination: + case 'soft': + img_mask = F.softmax(masks, dim=1) + case 'hard': + probs = F.softmax(masks - masks.logsumexp(dim=1,keepdim=True), dim=1) + img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + (probs - probs.detach()) + case 'qmix': + raise NotImplementedError + case _: + raise NotImplementedError + return img_mask + + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: + if not self.decode_vit: + return {} + if not self.encode_vit: + ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(self.vit_img_size, antialias=True)]) + obs = ToTensor(obs + 0.5) + else: + resize = tv.transforms.Resize(self.vit_img_size, antialias=True) + obs = resize(obs) + d_features = self.dino_vit(obs) + return {'d_features': d_features} + + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): + device = next(self.parameters()).device + # Tuple of State-Space state and prev slots + return State( + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.rssm_dim, + device=device), + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.latent_classes, + self.latent_dim, + device=device), + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.latent_classes * self.latent_dim, + device=device)), None + + def predict_next(self, prev_state: State, action): + prior, _ = self.recurrent_model.predict_next(prev_state, action) + + reward = self.reward_predictor(prior.combined).mode + if self.predict_discount: + discount_factors = self.discount_predictor(prior.combined).mode + else: + discount_factors = torch.ones_like(reward) + return prior, reward, discount_factors + + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, torch.Tensor]]) -> t.Tuple[State, torch.Tensor]: + if state is None or state[0] is None: + state, prev_slots = self.get_initial_state() + else: + if self.use_prev_slots: + state, prev_slots = state + else: + state, prev_slots = state[0], None + if self.encode_vit: + resize = tv.transforms.Resize(self.vit_img_size, antialias=True) + embed = self.encoder(resize(obs).unsqueeze(0)) + else: + embed = self.encoder(obs.unsqueeze(0)) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + + pre_slot_features_t = self.slot_mlp( + embed_with_pos_enc.permute(0, 2, 3, 1).reshape(1, -1, self.n_dim)) + + slots_t = 
self.slot_attention(pre_slot_features_t, prev_slots) + + _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), + action) + return posterior, slots_t + + def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): + self.recurrent_model.on_train_step() + b, _, h, w = obs.shape # s <- BxHxWx3 + + embed = self.encoder(obs) + if self.encode_vit: + embed = self.post_vit(additional['d_features']) + else: + embed = self.encoder(obs) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + + pre_slot_features = self.slot_mlp( + embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) + pre_slot_features_c = pre_slot_features.reshape(b // self.cluster_size, + self.cluster_size, -1, self.n_dim) + + a_c = a.reshape(-1, self.cluster_size, self.actions_num) + r_c = r.reshape(-1, self.cluster_size, 1) + d_c = discount.reshape(-1, self.cluster_size, 1) + first_c = first.reshape(-1, self.cluster_size, 1) + + losses = {} + metrics = {} + + def KL(dist1, dist2): + KL_ = torch.distributions.kl_divergence + kl_lhs = KL_(td.Independent(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), 1), + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist1), 1)).mean() + kl_rhs = KL_( + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist2), 1), + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist1.detach()), 1)).mean() + kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) + kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) + return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + + priors = [] + posteriors = [] + + if self.decode_vit: + d_features = additional['d_features'] + + prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) + for t in range(self.cluster_size): + # s_t <- 1xB^xHxWx3 + pre_slot_feature_t, a_t, first_t = pre_slot_features_c[:, + t], a_c[:, t].unsqueeze( + 0 + ), first_c[:, + t].unsqueeze( + 0) + a_t = a_t * (1 - first_t) + + slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) + # FIXME: prev_slots was not used properly, need to rerun test + if self.use_prev_slots: + prev_slots = self.slot_attention.prev_slots + else: + prev_slots = None + + prior, posterior, diff = self.recurrent_model.forward( + prev_state, slots_t.unsqueeze(0), a_t) + prev_state = posterior + + priors.append(prior) + posteriors.append(posterior) + + # losses['loss_determ_recons'] += diff + + posterior = State.stack(posteriors) + prior = State.stack(priors) + + r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) + f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) + + losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) + + if not self.decode_vit: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = self.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + + losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() + else: + if self.vit_l2_ratio != 1.0: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = self.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + img_rec = 
-x_r.log_prob(obs).float().mean() + else: + img_rec = 0 + decoded_imgs_detached, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = self.slot_mask(masks) + decoded_imgs_detached = decoded_imgs_detached * img_mask + x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) + losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() + + decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, self.vit_feat_dim+1, self.vit_size, self.vit_size).split([self.vit_feat_dim, 1], dim=2) + feat_mask = self.slot_mask(masks) + decoded_feats = decoded_feats * feat_mask + d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) + d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) + d_rec = -d_pred.log_prob(d_obs).float().mean() + d_rec = d_rec / torch.prod(torch.tensor(d_obs.shape[-3:])) * torch.prod(torch.tensor(obs.shape[-3:])) + + losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) + metrics['loss_l2_rec'] = img_rec + metrics['loss_dino_rec'] = d_rec + + prior_logits = prior.stoch_logits + posterior_logits = posterior.stoch_logits + losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() + losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() + losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) + + metrics['reward_mean'] = r.mean() + metrics['reward_std'] = r.std() + metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() + metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() + metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() + + losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + + self.kl_beta * losses['loss_kl_reg'] + self.discount_scale*losses['loss_discount_pred']) + + return losses, posterior, metrics + + diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py new file mode 100644 index 0000000..ed97145 --- /dev/null +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -0,0 +1,393 @@ +import typing as t + +import torch +import torch.distributions as td +import torchvision as tv +from torch import nn +from torch.nn import functional as F + +from rl_sandbox.agents.dreamer import Dist, Normalizer, View, get_position_encoding +from rl_sandbox.agents.dreamer.rssm_slots_attention import RSSM, State +from rl_sandbox.agents.dreamer.vision import SpatialBroadcastDecoder, Decoder, Encoder +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.utils.fc_nn import fc_nn_generator +from rl_sandbox.vision.dino import ViTFeat +from rl_sandbox.vision.slot_attention import PositionalEmbedding, SlotAttention + + +class WorldModel(nn.Module): + + def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, + actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, + discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, + decode_vit: bool, vit_l2_ratio: float, vit_img_size: int, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, + full_qk_from: int = 1, + symmetric_qk: bool = False, + attention_block_num: int = 3, + mask_combination: str = 'soft', + per_slot_rec_loss: bool = False, + spatial_decoder: bool = False): + super().__init__() + self.use_prev_slots = 
use_prev_slots + self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.discount_scale = discount_loss_scale + self.kl_beta = kl_loss_scale + + self.rssm_dim = rssm_dim + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.slots_num = slots_num + self.mask_combination = mask_combination + self.state_size = slots_num * (rssm_dim + latent_dim * latent_classes) + + self.cluster_size = batch_cluster_size + self.actions_num = actions_num + # kl loss balancing (prior/posterior) + self.alpha = kl_loss_balancing + self.predict_discount = predict_discount + self.encode_vit = encode_vit + self.decode_vit = decode_vit + self.vit_l2_ratio = vit_l2_ratio + self.vit_img_size = vit_img_size + self.per_slot_rec_loss = per_slot_rec_loss + + self.n_dim = 384 + + self.recurrent_model = RSSM( + latent_dim, + rssm_dim, + actions_num, + latent_classes, + discrete_rssm, + norm_layer=nn.LayerNorm if layer_norm else nn.Identity, + embed_size=self.n_dim, + full_qk_from=full_qk_from, + symmetric_qk=symmetric_qk, + attention_block_num=attention_block_num) + if encode_vit or decode_vit: + if self.vit_img_size == 224: + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=16) + self.decoder_kernels = [3, 3, 2] + self.vit_size = 14 + elif self.vit_img_size == 64: + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=8) + self.decoder_kernels = [3, 4] + self.vit_size = 8 + else: + raise RuntimeError("Unknown vit img size") + # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) + self.vit_feat_dim = self.dino_vit.feat_dim + self.dino_vit.requires_grad_(False) + + if encode_vit: + self.post_vit = nn.Sequential( + View((-1, self.vit_feat_dim, self.vit_size, self.vit_size)), + ) + self.encoder = nn.Sequential( + self.dino_vit, + self.post_vit + ) + else: + self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[4, 4], + channel_step=48 * (self.n_dim // 192) * 2, + post_conv_num=2, + flatten_output=False) + + self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) + self.register_buffer('pos_enc', torch.from_numpy(get_position_encoding(self.slots_num, self.state_size // slots_num)).to(dtype=torch.float32)) + if self.encode_vit: + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (14, 14)) + else: + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (14, 14)) + + self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), + nn.ReLU(inplace=True), + nn.Linear(self.n_dim, self.n_dim)) + + if decode_vit: + if spatial_decoder: + self.dino_predictor = SpatialBroadcastDecoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + out_image=(14, 14), + kernel_sizes = [5, 5, 5], + channel_step=self.vit_feat_dim, + output_channels=self.vit_feat_dim+1, + return_dist=False) + else: + self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + conv_kernel_sizes=[3], + channel_step=self.vit_feat_dim, + kernel_sizes=self.decoder_kernels, + output_channels=self.vit_feat_dim+1, + return_dist=False) + + self.image_predictor = Decoder( + rssm_dim + latent_dim * latent_classes, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + 
output_channels=3+1, + return_dist=False) + + self.reward_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.discount_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('binary')) + self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + + def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: + match self.mask_combination: + case 'soft': + img_mask = F.softmax(masks, dim=1) + case 'hard': + probs = F.softmax(masks - masks.logsumexp(dim=1,keepdim=True), dim=1) + img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + (probs - probs.detach()) + case 'qmix': + raise NotImplementedError + case _: + raise NotImplementedError + return img_mask + + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: + if not self.decode_vit: + return {} + if not self.encode_vit: + ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(self.vit_img_size, antialias=True)]) + obs = ToTensor(obs + 0.5) + else: + resize = tv.transforms.Resize(self.vit_img_size, antialias=True) + obs = resize(obs) + d_features = self.dino_vit(obs).squeeze() + return {'d_features': d_features} + + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): + device = next(self.parameters()).device + # Tuple of State-Space state and prev slots + return State( + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.rssm_dim, + device=device), + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.latent_classes, + self.latent_dim, + device=device), + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.latent_classes * self.latent_dim, + device=device), + self.pos_enc.unsqueeze(0).unsqueeze(0)), None + + def predict_next(self, prev_state: State, action): + prior, _ = self.recurrent_model.predict_next(prev_state, action) + + reward = self.reward_predictor(prior.combined).mode + if self.predict_discount: + discount_factors = self.discount_predictor(prior.combined).mode + else: + discount_factors = torch.ones_like(reward) + return prior, reward, discount_factors + + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, torch.Tensor]]) -> t.Tuple[State, torch.Tensor]: + if state is None or state[0] is None: + state, prev_slots = self.get_initial_state() + else: + if self.use_prev_slots: + state, prev_slots = state + else: + state, prev_slots = state[0], None + if self.encode_vit: + resize = tv.transforms.Resize(self.vit_img_size, antialias=True) + embed = self.encoder(resize(obs).unsqueeze(0)) + else: + embed = self.encoder(obs.unsqueeze(0)) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + + pre_slot_features_t = self.slot_mlp( + embed_with_pos_enc.permute(0, 2, 3, 1).reshape(1, -1, self.n_dim)) + + slots_t = self.slot_attention(pre_slot_features_t, prev_slots) + + _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), + action) + return posterior, slots_t + + def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): + self.recurrent_model.on_train_step() + b, _, h, w = obs.shape # s <- BxHxWx3 + + if self.encode_vit: + embed = 
self.post_vit(additional['d_features']) + else: + embed = self.encoder(obs) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + + pre_slot_features = self.slot_mlp( + embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) + pre_slot_features_c = pre_slot_features.reshape(b // self.cluster_size, + self.cluster_size, -1, self.n_dim) + + a_c = a.reshape(-1, self.cluster_size, self.actions_num) + r_c = r.reshape(-1, self.cluster_size, 1) + d_c = discount.reshape(-1, self.cluster_size, 1) + first_c = first.reshape(-1, self.cluster_size, 1) + + losses = {} + metrics = {} + + def KL(dist1, dist2): + KL_ = torch.distributions.kl_divergence + kl_lhs = KL_(td.Independent(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), 1), + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist1), 1)).mean() + kl_rhs = KL_( + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist2), 1), + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist1.detach()), 1)).mean() + kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) + kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) + return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + + priors = [] + posteriors = [] + + if self.decode_vit: + d_features = additional['d_features'] + + prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) + + self.last_attn = torch.zeros((self.slots_num, self.slots_num), device=a_c.device) + + prev_slots = (self.slot_attention.generate_initial(b // self.cluster_size)).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) + slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1), prev_slots.flatten(0, 1)).reshape(b // self.cluster_size, self.cluster_size, self.slots_num, -1) + + for t in range(self.cluster_size): + # s_t <- 1xB^xHxWx3 + slots_t, a_t, first_t = (slots_c[:,t], + a_c[:, t].unsqueeze(0), + first_c[:,t].unsqueeze(0)) + a_t = a_t * (1 - first_t) + + prior, posterior, diff = self.recurrent_model.forward( + prev_state, slots_t.unsqueeze(0), a_t) + prev_state = posterior + self.last_attn += self.recurrent_model.last_attention + + priors.append(prior) + posteriors.append(posterior) + + # losses['loss_determ_recons'] += diff + + self.last_attn /= self.cluster_size + + posterior = State.stack(posteriors) + prior = State.stack(priors) + + r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) + f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) + + losses['loss_reconstruction_img'] = torch.tensor(0, device=obs.device) + + if not self.decode_vit: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = self.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + + if self.per_slot_rec_loss: + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # magic constant that describes the difference between log_prob and mse losses + img_rec = l2_loss.mean() * normalizing_factor * 8 + else: + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean() + + losses['loss_reconstruction'] = img_rec + else: + if self.vit_l2_ratio != 1.0: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = self.slot_mask(masks) + decoded_imgs = decoded_imgs * 
img_mask + + if self.per_slot_rec_loss: + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # magic constant that describes the difference between log_prob and mse losses + img_rec = l2_loss.mean() * normalizing_factor * 8 + else: + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean() + else: + img_rec = torch.tensor(0, device=obs.device) + decoded_imgs_detached, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = self.slot_mask(masks) + decoded_imgs_detached = decoded_imgs_detached * img_mask + + if self.per_slot_rec_loss: + l2_loss = (img_mask * ((decoded_imgs_detached - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # magic constant that describes the difference between log_prob and mse losses + img_rec_detached = l2_loss.mean() * normalizing_factor * 8 + else: + x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) + img_rec_detached = -x_r_detached.log_prob(obs).float().mean() + + losses['loss_reconstruction_img'] = img_rec_detached + + decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, self.vit_size, self.vit_size).split([self.vit_feat_dim, 1], dim=2) + feat_mask = self.slot_mask(masks) + + d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) + + decoded_feats = decoded_feats * feat_mask + if self.per_slot_rec_loss: + l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(d_obs.shape)[-3:]) / feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=2).max(dim=2).values.max(dim=2).values * (64*64*3) + # # magic constant that describes the difference between log_prob and mse losses + d_rec = l2_loss.mean() * normalizing_factor * 4 + else: + d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) + d_rec = -d_pred.log_prob(d_obs).float().mean() + + d_rec = d_rec / torch.prod(torch.tensor(d_obs.shape[-3:])) * torch.prod(torch.tensor(obs.shape[-3:])) + + losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) + metrics['loss_l2_rec'] = img_rec + metrics['loss_dino_rec'] = d_rec + + prior_logits = prior.stoch_logits + posterior_logits = posterior.stoch_logits + losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() + losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() + losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) + + metrics['attention_coeff'] = torch.tensor(self.recurrent_model.attention_scheduler.val) + metrics['reward_mean'] = r.mean() + metrics['reward_std'] = r.std() + metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() + metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() + metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() + + losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + + self.kl_beta * losses['loss_kl_reg'] + self.discount_scale*losses['loss_discount_pred']) + + return losses, posterior, metrics + diff --git 
a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py new file mode 100644 index 0000000..77b8729 --- /dev/null +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -0,0 +1,373 @@ +import typing as t + +import torch +import torch.distributions as td +import torchvision as tv +from torch import nn +from torch.nn import functional as F + +from rl_sandbox.agents.dreamer import Dist, Normalizer, View, get_position_encoding +from rl_sandbox.agents.dreamer.rssm_slots_combined import RSSM, State +from rl_sandbox.agents.dreamer.vision import Decoder, Encoder +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.utils.fc_nn import fc_nn_generator +from rl_sandbox.vision.dino import ViTFeat +from rl_sandbox.vision.slot_attention import PositionalEmbedding, SlotAttention + + +class WorldModel(nn.Module): + + def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, + actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, + discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, + decode_vit: bool, vit_l2_ratio: float, vit_img_size: int, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, + mask_combination: str = 'soft', + per_slot_rec_loss: bool = False): + super().__init__() + self.use_prev_slots = use_prev_slots + self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.discount_scale = discount_loss_scale + self.kl_beta = kl_loss_scale + + self.rssm_dim = rssm_dim + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.slots_num = slots_num + self.mask_combination = mask_combination + self.state_size = slots_num * (rssm_dim + latent_dim * latent_classes) + + self.cluster_size = batch_cluster_size + self.actions_num = actions_num + # kl loss balancing (prior/posterior) + self.alpha = kl_loss_balancing + self.predict_discount = predict_discount + self.encode_vit = encode_vit + self.decode_vit = decode_vit + self.vit_l2_ratio = vit_l2_ratio + self.vit_img_size = vit_img_size + self.per_slot_rec_loss = per_slot_rec_loss + + self.n_dim = 192 + + self.recurrent_model = RSSM( + latent_dim, + rssm_dim, + actions_num, + latent_classes, + discrete_rssm, + norm_layer=nn.LayerNorm if layer_norm else nn.Identity, + slots_num=slots_num, + embed_size=self.n_dim) + if encode_vit or decode_vit: + if self.vit_img_size == 224: + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=16) + self.decoder_kernels = [3, 3, 2] + self.vit_size = 14 + elif self.vit_img_size == 64: + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=8) + self.decoder_kernels = [3, 4] + self.vit_size = 8 + else: + raise RuntimeError("Unknown vit img size") + # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) + self.vit_feat_dim = self.dino_vit.feat_dim + self.dino_vit.requires_grad_(False) + + if encode_vit: + self.post_vit = nn.Sequential( + View((-1, self.vit_feat_dim, self.vit_size, self.vit_size)), + ) + self.encoder = nn.Sequential( + self.dino_vit, + self.post_vit + ) + else: + self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[4, 4], + channel_step=48 * (self.n_dim // 192) * 2, + post_conv_num=2, + flatten_output=False) + + self.slot_attention = 
SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) + self.register_buffer('pos_enc', torch.from_numpy(get_position_encoding(self.slots_num, self.state_size // slots_num)).to(dtype=torch.float32)) + if self.encode_vit: + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (14, 14)) + else: + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) + + self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), + nn.ReLU(inplace=True), + nn.Linear(self.n_dim, self.n_dim)) + + if decode_vit: + self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + conv_kernel_sizes=[3], + channel_step=2*self.vit_feat_dim, + kernel_sizes=self.decoder_kernels, + output_channels=self.vit_feat_dim+1, + return_dist=False) + self.image_predictor = Decoder( + rssm_dim + latent_dim * latent_classes, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + output_channels=3+1, + return_dist=False) + + self.reward_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.discount_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('binary')) + self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + + def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: + match self.mask_combination: + case 'soft': + img_mask = F.softmax(masks, dim=1) + case 'hard': + probs = F.softmax(masks - masks.logsumexp(dim=1,keepdim=True), dim=1) + img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + (probs - probs.detach()) + case 'qmix': + raise NotImplementedError + case _: + raise NotImplementedError + return img_mask + + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: + if not self.decode_vit: + return {} + if not self.encode_vit: + ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(self.vit_img_size, antialias=True)]) + obs = ToTensor(obs + 0.5) + else: + resize = tv.transforms.Resize(self.vit_img_size, antialias=True) + obs = resize(obs) + d_features = self.dino_vit(obs).squeeze() + return {'d_features': d_features} + + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): + device = next(self.parameters()).device + # Tuple of State-Space state and prev slots + return State( + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.rssm_dim, + device=device), + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.latent_classes, + self.latent_dim, + device=device), + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.latent_classes * self.latent_dim, + device=device), + self.pos_enc.unsqueeze(0).unsqueeze(0)), None + + def predict_next(self, prev_state: State, action): + prior, _ = self.recurrent_model.predict_next(prev_state, action) + + reward = self.reward_predictor(prior.combined).mode + if self.predict_discount: + discount_factors = self.discount_predictor(prior.combined).mode + else: + discount_factors = torch.ones_like(reward) + return prior, reward, discount_factors + + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, torch.Tensor]]) -> t.Tuple[State, torch.Tensor]: + if state is None or state[0] is None: + state, prev_slots = 
self.get_initial_state() + else: + if self.use_prev_slots: + state, prev_slots = state + else: + state, prev_slots = state[0], None + if self.encode_vit: + resize = tv.transforms.Resize(self.vit_img_size, antialias=True) + embed = self.encoder(resize(obs).unsqueeze(0)) + else: + embed = self.encoder(obs.unsqueeze(0)) + embed = self.encoder(obs.unsqueeze(0)) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + + pre_slot_features_t = self.slot_mlp( + embed_with_pos_enc.permute(0, 2, 3, 1).reshape(1, -1, self.n_dim)) + + slots_t = self.slot_attention(pre_slot_features_t, prev_slots) + + _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), + action) + return posterior, slots_t + + def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): + self.recurrent_model.on_train_step() + b, _, h, w = obs.shape # s <- BxHxWx3 + + if self.encode_vit: + embed = self.post_vit(additional['d_features']) + else: + embed = self.encoder(obs) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + + pre_slot_features = self.slot_mlp( + embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) + pre_slot_features_c = pre_slot_features.reshape(b // self.cluster_size, + self.cluster_size, -1, self.n_dim) + + a_c = a.reshape(-1, self.cluster_size, self.actions_num) + r_c = r.reshape(-1, self.cluster_size, 1) + d_c = discount.reshape(-1, self.cluster_size, 1) + first_c = first.reshape(-1, self.cluster_size, 1) + + losses = {} + metrics = {} + + def KL(dist1, dist2): + KL_ = torch.distributions.kl_divergence + kl_lhs = KL_(td.Independent(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), 1), + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist1), 1)).mean() + kl_rhs = KL_( + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist2), 1), + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist1.detach()), 1)).mean() + kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) + kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) + return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + + priors = [] + posteriors = [] + + if self.decode_vit: + d_features = additional['d_features'] + + prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) + # slot_pos_enc = self.slot_emb(self.slot_indexer).unsqueeze(0) + # prev_slots = (self.slot_attention.generate_initial(b // self.cluster_size) + slot_pos_enc).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) + prev_slots = (self.slot_attention.generate_initial(b // self.cluster_size)).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) + slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1), prev_slots.flatten(0, 1)).reshape(b // self.cluster_size, self.cluster_size, self.slots_num, -1) + # slots_c = slots_c + slot_pos_enc.unsqueeze(0) + + for t in range(self.cluster_size): + # s_t <- 1xB^xHxWx3 + slots_t, a_t, first_t = (slots_c[:,t], + a_c[:, t].unsqueeze(0), + first_c[:,t].unsqueeze(0)) + a_t = a_t * (1 - first_t) + + prior, posterior, diff = self.recurrent_model.forward( + prev_state, slots_t.unsqueeze(0), a_t) + prev_state = posterior + + priors.append(prior) + posteriors.append(posterior) + + # losses['loss_determ_recons'] += diff + + posterior = State.stack(posteriors) + prior = State.stack(priors) + + r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) + f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) + + 
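For reference, the `KL` helper defined above follows DreamerV2-style KL balancing with free nats: the prior is pulled towards a stop-gradient copy of the posterior with weight `alpha`, the posterior is only weakly regularized towards a stop-gradient prior, and both terms are floored by the free-nats constant. A minimal standalone sketch (the function name and the default values here are illustrative, not taken from the repo):

```python
import torch
import torch.distributions as td

def balanced_kl(prior_logits: torch.Tensor,
                posterior_logits: torch.Tensor,
                alpha: float = 0.8,
                free_nats: float = 1.0) -> torch.Tensor:
    """Balanced KL between categorical latents given as logits (..., classes, dim)."""
    def dist(logits):
        # categorical over the last dim, independent across the latent variables
        return td.Independent(td.OneHotCategoricalStraightThrough(logits=logits), 1)
    kl = td.kl_divergence
    # train the prior towards the (frozen) posterior ...
    lhs = kl(dist(posterior_logits.detach()), dist(prior_logits)).mean().clamp(min=free_nats)
    # ... and weakly regularize the posterior towards the (frozen) prior
    rhs = kl(dist(posterior_logits), dist(prior_logits.detach())).mean().clamp(min=free_nats)
    return alpha * lhs + (1 - alpha) * rhs

# e.g. logits shaped (batch, latent_classes, latent_dim)
loss_kl = balanced_kl(torch.randn(8, 32, 32), torch.randn(8, 32, 32))
```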
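The `slot_mask` helpers above also offer a 'hard' alternative to the softmax combination: the forward pass uses a one-hot argmax over slots while gradients flow through the softmax probabilities via the straight-through trick (the logsumexp subtraction in the diff is just a numerically explicit softmax). Roughly, for mask logits of assumed shape (B, S, 1, H, W):

```python
import torch
import torch.nn.functional as F

def hard_slot_mask(mask_logits: torch.Tensor) -> torch.Tensor:
    # mask_logits: (B, S, 1, H, W) -- one logit per slot and pixel
    probs = F.softmax(mask_logits, dim=1)
    one_hot = F.one_hot(mask_logits.argmax(dim=1),
                        num_classes=mask_logits.shape[1])      # (B, 1, H, W, S)
    one_hot = one_hot.permute(0, 4, 1, 2, 3).to(probs.dtype)   # back to (B, S, 1, H, W)
    # forward value is the one-hot mask; the gradient is that of the soft probabilities
    return one_hot + probs - probs.detach()
```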
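Downstream of the masks, each slot decodes an RGB image plus a mask logit; the masked slot images are alpha-composited and scored under a unit-variance Gaussian, which is what the reconstruction branches below compute. A compact sketch with assumed shapes (the function name is illustrative):

```python
import torch
import torch.distributions as td
import torch.nn.functional as F

def masked_reconstruction_loss(decoded: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    # decoded: (B, S, 4, H, W) per-slot RGB + mask logit; obs: (B, 3, H, W)
    rgb, mask_logits = decoded.split([3, 1], dim=2)
    mask = F.softmax(mask_logits, dim=1)              # 'soft' combination over slots
    recon = (rgb * mask).sum(dim=1)                   # composite slots -> (B, 3, H, W)
    dist = td.Independent(td.Normal(recon, 1.0), 3)   # unit-variance Gaussian per pixel
    return -dist.log_prob(obs).mean()                 # negative log-likelihood

# usage with dummy tensors
loss = masked_reconstruction_loss(torch.randn(2, 4, 4, 64, 64), torch.randn(2, 3, 64, 64))
```

Since -log N(x; mu, 1) = (x - mu)^2 / 2 + log(2*pi)/2 per dimension, the `per_slot_rec_loss` variants in the diff are essentially rescaled MSE terms plus that same constant, which is what the "magic constant" comments refer to.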
losses['loss_reconstruction_img'] = torch.tensor(0, device=obs.device) + + if not self.decode_vit: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = self.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + + if self.per_slot_rec_loss: + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # magic constant that describes the difference between log_prob and mse losses + img_rec = l2_loss.mean() * normalizing_factor * 8 + else: + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean() + losses['loss_reconstruction'] = img_rec + else: + if self.vit_l2_ratio != 1.0: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = self.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + + if self.per_slot_rec_loss: + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # magic constant that describes the difference between log_prob and mse losses + img_rec = l2_loss.mean() * normalizing_factor * 8 + else: + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean() + else: + img_rec = torch.tensor(0, device=obs.device) + decoded_imgs_detached, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = self.slot_mask(masks) + decoded_imgs_detached = decoded_imgs_detached * img_mask + + if self.per_slot_rec_loss: + l2_loss = (img_mask * ((decoded_imgs_detached - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # magic constant that describes the difference between log_prob and mse losses + img_rec_detached = l2_loss.mean() * normalizing_factor * 8 + else: + x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) + img_rec_detached = -x_r_detached.log_prob(obs).float().mean() + + losses['loss_reconstruction_img'] = img_rec_detached + + decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, self.vit_size, self.vit_size).split([self.vit_feat_dim, 1], dim=2) + feat_mask = self.slot_mask(masks) + + d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) + + decoded_feats = decoded_feats * feat_mask + if self.per_slot_rec_loss: + l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(d_obs.shape)[-3:]) / feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=2).max(dim=2).values.max(dim=2).values * (64*64*3) + # # magic constant that describes the difference between log_prob and mse losses + d_rec = l2_loss.mean() * normalizing_factor * 4 + else: + d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) + d_rec = -d_pred.log_prob(d_obs).float().mean() + + d_rec = d_rec / torch.prod(torch.tensor(d_obs.shape[-3:])) * 
torch.prod(torch.tensor(obs.shape[-3:])) + + losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) + metrics['loss_l2_rec'] = img_rec + metrics['loss_dino_rec'] = d_rec + + prior_logits = prior.stoch_logits + posterior_logits = posterior.stoch_logits + losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() + losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() + losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) + + metrics['reward_mean'] = r.mean() + metrics['reward_std'] = r.std() + metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() + metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() + metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() + + losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + + self.kl_beta * losses['loss_kl_reg'] + self.discount_scale*losses['loss_discount_pred']) + + return losses, posterior, metrics + diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py new file mode 100644 index 0000000..cbecf7b --- /dev/null +++ b/rl_sandbox/agents/dreamer_v2.py @@ -0,0 +1,245 @@ +import typing as t +from pathlib import Path + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +import torchvision as tv +from unpackable import unpack + +from rl_sandbox.agents.rl_agent import RlAgent +from rl_sandbox.utils.replay_buffer import (Action, Observation, + RolloutChunks, EnvStep, Rollout) + +from rl_sandbox.agents.dreamer.world_model import WorldModel, State +from rl_sandbox.agents.dreamer.ac import ImaginativeCritic, ImaginativeActor + + +class DreamerV2(RlAgent): + + def __init__( + self, + obs_space_num: list[int], # NOTE: encoder/decoder will work only with 64x64 currently + clip_rewards: str, + actions_num: int, + world_model: t.Any, + actor: t.Any, + critic: t.Any, + action_type: str, + imagination_horizon: int, + wm_optim: t.Any, + actor_optim: t.Any, + critic_optim: t.Any, + layer_norm: bool, + batch_cluster_size: int, + f16_precision: bool, + device_type: str = 'cpu', + logger = None): + + self.logger = logger + self.device = device_type + self.imagination_horizon = imagination_horizon + self.actions_num = actions_num + self.is_discrete = (action_type != 'continuous') + match clip_rewards: + case 'identity': + self.reward_clipper = nn.Identity() + case 'tanh': + self.reward_clipper = nn.Tanh() + case _: + raise RuntimeError('Invalid reward clipping') + self.is_f16 = f16_precision + + self.world_model: WorldModel = torch.compile(world_model(actions_num=actions_num), mode='max-autotune').to(device_type) + self.actor: ImaginativeActor = actor(latent_dim=self.world_model.state_size, + actions_num=actions_num, + is_discrete=self.is_discrete).to(device_type) + self.critic: ImaginativeCritic = torch.compile(critic(latent_dim=self.world_model.state_size), mode='max-autotune').to(device_type) + + self.world_model_optimizer = wm_optim(model=self.world_model, scaler=self.is_f16) + # if self.world_model.decode_vit and self.world_model.vit_l2_ratio == 1.0: + # self.image_predictor_optimizer = wm_optim(model=self.world_model.image_predictor, scaler=self.is_f16) + self.actor_optimizer = actor_optim(model=self.actor) + self.critic_optimizer = critic_optim(model=self.critic) + + self.reset() + + def imagine_trajectory( + self, init_state: State, precomp_actions: t.Optional[list[Action]] = None, horizon: t.Optional[int] = None + ) -> tuple[State, torch.Tensor, torch.Tensor, + 
torch.Tensor]: + if horizon is None: + horizon = self.imagination_horizon + + prev_state = init_state + prev_action = torch.zeros_like(self.actor(prev_state.combined.detach()).mean) + states, actions, rewards, ts = ([init_state], + [prev_action], + [self.world_model.reward_predictor(init_state.combined).mode], + [torch.ones(prev_action.shape[:-1] + (1,), device=prev_action.device)]) + + for i in range(horizon): + if precomp_actions is not None: + a = precomp_actions[i].unsqueeze(0) + else: + a_dist = self.actor(prev_state.combined.detach()) + a = a_dist.rsample() + prior, reward, discount = self.world_model.predict_next(prev_state, a) + prev_state = prior + + states.append(prior) + rewards.append(reward) + ts.append(discount) + actions.append(a) + + return (states[0].stack(states), torch.cat(actions), torch.cat(rewards), torch.cat(ts)) + + def reset(self): + self._state = self.world_model.get_initial_state() + self._last_action = torch.zeros((1, 1, self.actions_num), device=self.device) + self._action_probs = torch.zeros((self.actions_num), device=self.device) + + def preprocess(self, rollout: Rollout): + obs = self.preprocess_obs(rollout.obs) + additional = self.world_model.precalc_data(obs.to(self.device)) + return Rollout(obs=obs, + actions=rollout.actions, + rewards=self.reward_clipper(rollout.rewards), + is_finished=rollout.is_finished, + is_first=rollout.is_first, + additional_data=rollout.additional_data | additional) + + def preprocess_obs(self, obs: torch.Tensor): + order = list(range(len(obs.shape))) + # Swap channel from last to 3 from last + order = order[:-3] + [order[-1]] + order[-3:-1] + if self.world_model.encode_vit: + ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + return ToTensor(obs.type(torch.float32).permute(order) / 255.0) + else: + return ((obs.type(torch.float32) / 255.0) - 0.5).permute(order) + + def unprocess_obs(self, obs: torch.Tensor): + order = list(range(len(obs.shape))) + # # Swap channel from last to 3 from last + order = order[:-3] + order[-2:] + [order[-3]] + if self.world_model.encode_vit: + fromTensor = tv.transforms.Compose([ tv.transforms.Normalize(mean = [ 0., 0., 0. ], + std = [ 1/0.229, 1/0.224, 1/0.225 ]), + tv.transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ], + std = [ 1., 1., 1. 
]), + ]) + return (fromTensor(obs).clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) + else: + return ((obs + 0.5).clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) + # return obs.type(torch.float32).permute(order) + + def get_action(self, obs: Observation) -> Action: + obs = torch.from_numpy(obs).to(self.device) + obs = self.preprocess_obs(obs) + + self._state = self.world_model.get_latent(obs, self._last_action, self._state) + + actor_dist = self.actor.get_action(self._state) + self._last_action = actor_dist.sample() + + if self.is_discrete: + self._action_probs += actor_dist.probs.squeeze() + + if self.is_discrete: + return self._last_action.argmax() + else: + return self._last_action.squeeze().detach().cpu() + + def from_np(self, arr: np.ndarray): + arr = torch.from_numpy(arr) if isinstance(arr, np.ndarray) else arr + return arr.to(self.device, non_blocking=True) + + def train(self, rollout_chunks: RolloutChunks): + obs, a, r, is_finished, is_first, additional = unpack(rollout_chunks) + if self.is_discrete: + a = F.one_hot(a.to(torch.int64), num_classes=self.actions_num).squeeze() + discount_factors = self.critic.gamma*(1 - is_finished).float() + first_flags = is_first.float() + + # take some latent embeddings as initial + with torch.cuda.amp.autocast(enabled=self.is_f16): + losses_wm, discovered_states, metrics_wm = self.world_model.calculate_loss(obs, a, r, discount_factors, first_flags, additional) + # FIXME: wholely remove discrete RSSM + # self.world_model.recurrent_model.discretizer_scheduler.step() + + # if self.world_model.decode_vit and self.world_model.vit_l2_ratio == 1.0: + # self.image_predictor_optimizer.step(losses_wm['loss_reconstruction_img']) + + + metrics_wm |= self.world_model_optimizer.step(losses_wm['loss_wm']) + + with torch.cuda.amp.autocast(enabled=self.is_f16): + initial_states = discovered_states.flatten().detach() + + states, actions, rewards, discount_factors = self.imagine_trajectory(initial_states) + + rewards = rewards.float() + discount_factors = discount_factors.float() + + zs = states.combined + rewards = self.world_model.reward_normalizer(rewards) + + vs = self.critic.lambda_return(zs, rewards[:-1], discount_factors) + + # Discounted factors should be shifted as they predict whether next state cannot be used + # First discount factor on contrary is always 1 as it cannot lead to trajectory finish + discount_factors = torch.cat([torch.ones_like(discount_factors[:1]), discount_factors[:-1]], dim=0) + + # Ignore all factors after first is_finished state + discount_factors = torch.cumprod(discount_factors, dim=0).detach() + + losses_c, metrics_c = self.critic.calculate_loss(zs[:-1], vs, discount_factors[:-1]) + + # last action should be ignored as it is not used to predict next state, thus no feedback + # first value should be ignored as it is comes from replay buffer + losses_a, metrics_a = self.actor.calculate_loss(zs[:-2], + vs[1:], + self.critic.target_critic(zs[:-2]).mode, + discount_factors[:-2], + actions[1:-1]) + metrics_a |= self.actor_optimizer.step(losses_a['loss_actor']) + metrics_c |= self.critic_optimizer.step(losses_c['loss_critic']) + + self.critic.update_target() + + losses = losses_wm | losses_a | losses_c + metrics = metrics_wm | metrics_a | metrics_c + + losses = {l: val.detach().cpu().numpy() for l, val in losses.items()} + metrics = {l: val.detach().cpu().numpy() for l, val in metrics.items()} + + losses['total'] = sum(losses.values()) + return losses | metrics + + def save_ckpt(self, epoch_num: int, losses: dict[str, float]): + torch.save( 
+ { + 'epoch': epoch_num, + 'world_model_state_dict': self.world_model.state_dict(), + 'world_model_optimizer_state_dict': self.world_model_optimizer.optimizer.state_dict(), + 'actor_state_dict': self.actor.state_dict(), + 'critic_state_dict': self.critic.state_dict(), + 'actor_optimizer_state_dict': self.actor_optimizer.optimizer.state_dict(), + 'critic_optimizer_state_dict': self.critic_optimizer.optimizer.state_dict(), + 'losses': losses + }, f'dreamerV2-{epoch_num}-{losses["total"]}.ckpt') + + def load_ckpt(self, ckpt_path: Path): + ckpt = torch.load(ckpt_path) + self.world_model.load_state_dict(ckpt['world_model_state_dict']) + # FIXME: doesn't work for optimizers + self.world_model_optimizer.load_state_dict( + ckpt['world_model_optimizer_state_dict']) + self.actor.load_state_dict(ckpt['actor_state_dict']) + self.critic.load_state_dict(ckpt['critic_state_dict']) + self.actor_optimizer.load_state_dict(ckpt['actor_optimizer_state_dict']) + self.critic_optimizer.load_state_dict(ckpt['critic_optimizer_state_dict']) + return ckpt['epoch'] diff --git a/rl_sandbox/agents/explorative_agent.py b/rl_sandbox/agents/explorative_agent.py new file mode 100644 index 0000000..444ca70 --- /dev/null +++ b/rl_sandbox/agents/explorative_agent.py @@ -0,0 +1,32 @@ +import numpy as np +from nptyping import Float, NDArray, Shape +from pathlib import Path + +from rl_sandbox.agents.rl_agent import RlAgent +from rl_sandbox.utils.schedulers import Scheduler +from rl_sandbox.utils.replay_buffer import (Action, Actions, Rewards, State, + States, TerminationFlags) + + +class ExplorativeAgent(RlAgent): + def __init__(self, policy_agent: RlAgent, + exploration_agent: RlAgent, + scheduler: Scheduler): + self.policy_ag = policy_agent + self.expl_ag = exploration_agent + self.scheduler = scheduler + + def get_action(self, obs: State) -> Action | NDArray[Shape["*"],Float]: + if np.random.random() > self.scheduler.step(): + return self.expl_ag.get_action(obs) + return self.policy_ag.get_action(obs) + + def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): + return self.expl_ag.train(s, a, r, next, is_finished) | self.policy_ag.train(s, a, r, next, is_finished) + + def save_ckpt(self, epoch_num: int, losses: dict[str, float]): + self.policy_ag.save_ckpt(epoch_num, losses) + self.expl_ag.save_ckpt(epoch_num, losses) + + def load_ckpt(self, ckpt_path: Path): + pass diff --git a/rl_sandbox/agents/random_agent.py b/rl_sandbox/agents/random_agent.py new file mode 100644 index 0000000..0fbb1bc --- /dev/null +++ b/rl_sandbox/agents/random_agent.py @@ -0,0 +1,26 @@ +import numpy as np +import torch +from nptyping import Float, NDArray, Shape +from pathlib import Path + +from rl_sandbox.agents.rl_agent import RlAgent +from rl_sandbox.utils.env import Env +from rl_sandbox.utils.replay_buffer import (Action, Actions, Rewards, State, + States, TerminationFlags) + + +class RandomAgent(RlAgent): + def __init__(self, env: Env): + self.action_space = env.action_space + + def get_action(self, obs: State) -> Action | NDArray[Shape["*"],Float]: + return torch.from_numpy(np.array(self.action_space.sample())) + + def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): + return dict() + + def save_ckpt(self, epoch_num: int, losses: dict[str, float]): + pass + + def load_ckpt(self, ckpt_path: Path): + pass diff --git a/rl_sandbox/agents/rl_agent.py b/rl_sandbox/agents/rl_agent.py index 97c4fa7..357ec82 100644 --- a/rl_sandbox/agents/rl_agent.py +++ 
b/rl_sandbox/agents/rl_agent.py @@ -1,6 +1,8 @@ +from typing import Any from abc import ABCMeta, abstractmethod +from pathlib import Path -from rl_sandbox.utils.replay_buffer import Action, State, States, Actions, Rewards +from rl_sandbox.utils.replay_buffer import Action, State, States, Actions, Rewards, TerminationFlags class RlAgent(metaclass=ABCMeta): @abstractmethod @@ -8,5 +10,21 @@ def get_action(self, obs: State) -> Action: pass @abstractmethod - def train(self, s: States, a: Actions, r: Rewards, next: States): + def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags) -> dict[str, Any]: + """ + Return dict with losses for logging + """ + pass + + # Some models can have internal state which should be + # properly reseted between rollouts + def reset(self): + pass + + @abstractmethod + def save_ckpt(self, epoch_num: int, losses: dict[str, float]): + pass + + @abstractmethod + def load_ckpt(self, ckpt_path: Path): pass diff --git a/rl_sandbox/config/agent/dqn.yaml b/rl_sandbox/config/agent/dqn.yaml new file mode 100644 index 0000000..dfea883 --- /dev/null +++ b/rl_sandbox/config/agent/dqn.yaml @@ -0,0 +1,5 @@ +name: dqn +_target_: rl_sandbox.agents.DqnAgent +hidden_layer_size: 16 +num_layers: 1 +discount_factor: 0.999 diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml new file mode 100644 index 0000000..5d91ce5 --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -0,0 +1,70 @@ +_target_: rl_sandbox.agents.DreamerV2 + +clip_rewards: identity +imagination_horizon: 15 +batch_cluster_size: 50 +layer_norm: false + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model.WorldModel + _partial_: true + batch_cluster_size: ${..batch_cluster_size} + latent_dim: 32 + latent_classes: 32 + rssm_dim: 200 + discount_loss_scale: 1.0 + kl_loss_scale: 2 + kl_loss_balancing: 0.8 + kl_free_nats: 1.00 + discrete_rssm: false + decode_vit: false + vit_l2_ratio: 0.5 + vit_img_size: 224 + encode_vit: false + predict_discount: false + layer_norm: ${..layer_norm} + +actor: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor + _partial_: true + # mixing of reinforce and maximizing value func + # for dm_control it is zero in Dreamer (Atari 1) + reinforce_fraction: null + entropy_scale: 1e-5 + layer_norm: ${..layer_norm} + +critic: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic + _partial_: true + discount_factor: 0.99 + update_interval: 100 + # [0-1], 1 means hard update + soft_update_fraction: 1 + # Lambda parameter for trainin deeper multi-step prediction + value_target_lambda: 0.95 + layer_norm: ${..layer_norm} + +wm_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr_scheduler: null + lr: 3e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +actor_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 8e-5 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +critic_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 8e-5 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 diff --git a/rl_sandbox/config/agent/dreamer_v2_atari.yaml b/rl_sandbox/config/agent/dreamer_v2_atari.yaml new file mode 100644 index 0000000..8f080e2 --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_atari.yaml @@ -0,0 +1,27 @@ +defaults: + - dreamer_v2 + - _self_ + +clip_rewards: tanh +layer_norm: true + +world_model: + rssm_dim: 600 + kl_loss_scale: 0.1 + discount_loss_scale: 5.0 + predict_discount: true + +actor: + entropy_scale: 1e-3 + +critic: + 
discount_factor: 0.999 + +wm_optim: + lr: 2e-4 + +actor_optim: + lr: 4e-5 + +critic_optim: + lr: 1e-4 diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml new file mode 100644 index 0000000..8839422 --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -0,0 +1,25 @@ +defaults: + - dreamer_v2 + - _self_ + +clip_rewards: tanh +layer_norm: true + +world_model: + rssm_dim: 1024 + predict_discount: true + +actor: + entropy_scale: 3e-3 + +critic: + discount_factor: 0.999 + +wm_optim: + lr: 1e-4 + +actor_optim: + lr: 1e-4 + +critic_optim: + lr: 1e-4 diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml new file mode 100644 index 0000000..d9f7416 --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml @@ -0,0 +1,14 @@ +defaults: + - dreamer_v2_crafter + - _self_ + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model_slots.WorldModel + rssm_dim: 512 + slots_num: 6 + slots_iter_num: 2 + kl_loss_scale: 1.0 + decode_vit: true + use_prev_slots: false + vit_l2_ratio: 0.1 + encode_vit: false diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml new file mode 100644 index 0000000..2830e33 --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -0,0 +1,28 @@ +defaults: + - dreamer_v2_crafter_slotted + - _self_ + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model_slots_attention.WorldModel + rssm_dim: 768 + slots_num: 4 + slots_iter_num: 3 + kl_loss_scale: 1.0 + encode_vit: false + decode_vit: true + mask_combination: soft + use_prev_slots: false + per_slot_rec_loss: false + vit_l2_ratio: 0.5 + + full_qk_from: 4e4 + symmetric_qk: false + attention_block_num: 3 + + spatial_decoder: false + +wm_optim: + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml new file mode 100644 index 0000000..82eff7a --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml @@ -0,0 +1,22 @@ +defaults: + - dreamer_v2_crafter_slotted + - _self_ + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model_slots_combined.WorldModel + rssm_dim: 512 + slots_num: 5 + slots_iter_num: 2 + kl_loss_scale: 1.0 + encode_vit: false + decode_vit: true + mask_combination: soft + use_prev_slots: false + per_slot_rec_loss: false + vit_l2_ratio: 0.1 + +wm_optim: + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_debug.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_debug.yaml new file mode 100644 index 0000000..bdce5e9 --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_debug.yaml @@ -0,0 +1,78 @@ +_target_: rl_sandbox.agents.DreamerV2 + +clip_rewards: tanh +imagination_horizon: 15 +batch_cluster_size: 50 +layer_norm: true + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model_slots_attention.WorldModel + _partial_: true + batch_cluster_size: ${..batch_cluster_size} + latent_dim: 32 + latent_classes: 32 + rssm_dim: 200 + slots_num: 4 + slots_iter_num: 2 + kl_loss_scale: 1000 + kl_loss_balancing: 0.8 + kl_free_nats: 0.0005 + discrete_rssm: false + decode_vit: true + vit_l2_ratio: 0.75 + use_prev_slots: false + 
encode_vit: false + predict_discount: false + layer_norm: ${..layer_norm} + +actor: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor + _partial_: true + # mixing of reinforce and maximizing value func + # for dm_control it is zero in Dreamer (Atari 1) + reinforce_fraction: null + entropy_scale: 1e-4 + layer_norm: ${..layer_norm} + +critic: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic + _partial_: true + discount_factor: 0.999 + update_interval: 100 + # [0-1], 1 means hard update + soft_update_fraction: 1 + # Lambda parameter for trainin deeper multi-step prediction + value_target_lambda: 0.95 + layer_norm: ${..layer_norm} + +wm_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 + #- _target_: rl_sandbox.utils.optimizer.DecayScheduler + # _partial_: true + # decay_rate: 0.5 + # decay_steps: 5e5 + lr: 3e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +actor_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 8e-5 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +critic_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 8e-5 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml new file mode 100644 index 0000000..c777d19 --- /dev/null +++ b/rl_sandbox/config/config.yaml @@ -0,0 +1,66 @@ +defaults: + - agent: dreamer_v2 + - env: dm_cartpole + - training: dm + - logger: wandb + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +agent: + world_model: + _target_: rl_sandbox.agents.dreamer.world_model.WorldModel + rssm_dim: 200 + + encode_vit: false + decode_vit: false + #vit_l2_ratio: 1.0 + + #kl_loss_scale: 2.0 + #kl_loss_balancing: 0.8 + #kl_free_nats: 1.0 + + #wm_optim: + # lr_scheduler: + # - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + # _partial_: true + # warmup_steps: 1e3 + +logger: + message: Default dreamer fp16 + log_grads: false + +training: + f16_precision: true + checkpoint_path: null + steps: 1e6 + val_logs_every: 2e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator + _partial_: true + #- _target_: rl_sandbox.metrics.PostSlottedDreamerMetricsEvaluator + # _partial_: true + #- _target_: rl_sandbox.crafter_metrics.CrafterMetricsEvaluator + # _partial_: true + +debug: + profiler: false + +hydra: + mode: MULTIRUN + #mode: RUN + launcher: + n_jobs: 3 + sweeper: + params: + seed: 17,42,45 + env: dm_finger_spin,dm_finger_turn_hard diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml new file mode 100644 index 0000000..96b01e5 --- /dev/null +++ b/rl_sandbox/config/config_attention.yaml @@ -0,0 +1,66 @@ +defaults: + - agent: dreamer_v2_slotted_attention + - env: crafter + - training: crafter + - logger: wandb + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda:0 + +agent: + world_model: + encode_vit: false + decode_vit: true + vit_img_size: 224 + vit_l2_ratio: 1.0 + slots_iter_num: 3 + slots_num: 4 + kl_loss_scale: 3.0 + kl_loss_balancing: 0.6 + kl_free_nats: 1.0 + + actor_optim: + lr: 1e-4 + + critic_optim: + lr: 1e-4 + +logger: + message: Attention, only dino, kl=0.6/3, 14x14, 768 rssm, no fp16, reverse dino + log_grads: false + +training: + 
f16_precision: false + checkpoint_path: null + steps: 1e6 + val_logs_every: 2e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator + #- _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + _partial_: true + - _target_: rl_sandbox.crafter_metrics.CrafterMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + n_jobs: 1 + # n_jobs: 8 + #sweeper: + # params: + # agent.world_model.full_qk_from: 1,2e4 + # agent.world_model.symmetric_qk: true,false + # agent.world_model.attention_block_num: 1,3 diff --git a/rl_sandbox/config/config_combined.yaml b/rl_sandbox/config/config_combined.yaml new file mode 100644 index 0000000..f874214 --- /dev/null +++ b/rl_sandbox/config/config_combined.yaml @@ -0,0 +1,55 @@ +defaults: + - agent: dreamer_v2_slotted_combined + - env: crafter + - logger: wandb + - training: crafter + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda:1 + +agent: + world_model: + encode_vit: false + decode_vit: false + #vit_img_size: 224 + #vit_l2_ratio: 0.5 + slots_iter_num: 3 + slots_num: 6 + kl_loss_scale: 1.0 + kl_free_nats: 1.0 + +logger: + message: Combined, without dino, added pos encoding for reconstruction + log_grads: false + +training: + checkpoint_path: null + val_logs_every: 2e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + n_jobs: 1 + #sweeper: + # params: + # agent.world_model.kl_loss_scale: 1e2,1e1 + # agent.world_model.slots_num: 3,6 + # agent.world_model.per_slot_rec_loss: true + # agent.world_model.mask_combination: soft,hard + # agent.world_model.vit_l2_ratio: 0.1,1e-3 diff --git a/rl_sandbox/config/config_default.yaml b/rl_sandbox/config/config_default.yaml new file mode 100644 index 0000000..0607307 --- /dev/null +++ b/rl_sandbox/config/config_default.yaml @@ -0,0 +1,44 @@ +defaults: + - agent: dreamer_v2_crafter + - env: crafter + - training: crafter + - logger: wandb + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +logger: + message: Crafter default + log_grads: false + +training: + checkpoint_path: null + steps: 1e6 + val_logs_every: 2e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + n_jobs: 1 + #sweeper: + # params: + # agent.world_model._target_: rl_sandbox.agents.dreamer.world_model_slots_combined.WorldModel,rl_sandbox.agents.dreamer.world_model_slots_attention.WorldModel + # agent.world_model.vit_l2_ratio: 0.1,0.5 + # agent.world_model.kl_loss_scale: 1e1,1e2,1e3,1e4 + # agent.world_model.vit_l2_ratio: 0.1,0.9 diff --git a/rl_sandbox/config/config_dino.yaml b/rl_sandbox/config/config_dino.yaml new file mode 100644 index 0000000..569fba2 --- /dev/null +++ b/rl_sandbox/config/config_dino.yaml @@ -0,0 +1,43 @@ +defaults: + - agent: dreamer_v2 + - env: dm_quadruped + - training: dm + - _self_ + - override hydra/launcher: joblib + +seed: 42 
+device_type: cuda + +logger: + type: tensorboard + message: Quadruped with DINO features + log_grads: false + +training: + checkpoint_path: null + steps: 1e6 + val_logs_every: 1e4 + + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + #n_jobs: 8 + n_jobs: 1 + sweeper: + #params: + # agent.world_model.kl_loss_scale: 1e-4,1e-3,1e-2,0.1,1.0,1e2,1e3,1e4 diff --git a/rl_sandbox/config/config_dino_1.yaml b/rl_sandbox/config/config_dino_1.yaml new file mode 100644 index 0000000..3c082c1 --- /dev/null +++ b/rl_sandbox/config/config_dino_1.yaml @@ -0,0 +1,43 @@ +defaults: + - agent: dreamer_v2 + - env: dm_cheetah + - training: dm + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +logger: + type: tensorboard + message: Cheetah with DINO features, 0.75 ratio + log_grads: false + +training: + checkpoint_path: null + steps: 2e6 + val_logs_every: 2e4 + + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + #n_jobs: 8 + n_jobs: 1 + #sweeper: + #params: + # agent.world_model.kl_loss_scale: 1e-4,1e-3,1e-2,0.1,1.0,1e2,1e3,1e4 diff --git a/rl_sandbox/config/config_dino_2.yaml b/rl_sandbox/config/config_dino_2.yaml new file mode 100644 index 0000000..fd42a77 --- /dev/null +++ b/rl_sandbox/config/config_dino_2.yaml @@ -0,0 +1,44 @@ +defaults: + - agent: dreamer_v2 + - env: dm_acrobot + - training: dm + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +logger: + type: tensorboard + message: Acrobot default + log_grads: false + +training: + checkpoint_path: null + steps: 2e6 + val_logs_every: 2e4 + + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + #n_jobs: 8 + n_jobs: 1 + #sweeper: + #params: + # agent.world_model.kl_loss_scale: 1e-4,1e-3,1e-2,0.1,1.0,1e2,1e3,1e4 + diff --git a/rl_sandbox/config/config_postslot.yaml b/rl_sandbox/config/config_postslot.yaml new file mode 100644 index 0000000..85b4b9f --- /dev/null +++ b/rl_sandbox/config/config_postslot.yaml @@ -0,0 +1,71 @@ +defaults: + - agent: dreamer_v2 + - env: dm_acrobot + - training: dm + - logger: wandb + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +agent: + world_model: + _target_: rl_sandbox.agents.dreamer.world_model_post_slot.WorldModel + rssm_dim: 256 + slots_num: 5 + slots_iter_num: 3 + + encode_vit: false + decode_vit: false + mask_combination: soft + vit_l2_ratio: 1.0 + + vit_img_size: 224 + kl_loss_scale: 1.0 + kl_loss_balancing: 0.8 + kl_free_nats: 1.0 + + use_reshuffle: true + per_slot_rec_loss: false + spatial_decoder: false + + wm_optim: + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 + +logger: + message: Post-wm slot attention, n_dim=256 + log_grads: false + +training: + f16_precision: true + checkpoint_path: null + steps: 1e6 
+ val_logs_every: 2e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.PostSlottedDreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + mode: MULTIRUN + #mode: RUN + launcher: + n_jobs: 3 + sweeper: + params: + seed: 17,42,45 + env: dm_finger_spin,dm_finger_turn_hard + diff --git a/rl_sandbox/config/config_postslot_dino.yaml b/rl_sandbox/config/config_postslot_dino.yaml new file mode 100644 index 0000000..94e9a11 --- /dev/null +++ b/rl_sandbox/config/config_postslot_dino.yaml @@ -0,0 +1,71 @@ +defaults: + - agent: dreamer_v2 + - env: dm_acrobot + - training: dm + - logger: wandb + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +agent: + world_model: + _target_: rl_sandbox.agents.dreamer.world_model_post_slot.WorldModel + rssm_dim: 256 + slots_num: 5 + slots_iter_num: 3 + + encode_vit: false + decode_vit: true + mask_combination: soft + vit_l2_ratio: 1.0 + + vit_img_size: 224 + kl_loss_scale: 1.0 + kl_loss_balancing: 0.8 + kl_free_nats: 1.0 + + use_reshuffle: true + per_slot_rec_loss: false + spatial_decoder: false + + wm_optim: + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 + +logger: + message: Post-wm dino slot attention, n_dim=256 + log_grads: false + +training: + f16_precision: true + checkpoint_path: null + steps: 1e6 + val_logs_every: 2e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.PostSlottedDinoDreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + n_jobs: 3 + sweeper: + params: + seed: 17,42,45 + env: dm_finger_spin,dm_finger_turn_hard + diff --git a/rl_sandbox/config/config_slotted.yaml b/rl_sandbox/config/config_slotted.yaml new file mode 100644 index 0000000..0a218e4 --- /dev/null +++ b/rl_sandbox/config/config_slotted.yaml @@ -0,0 +1,44 @@ +defaults: + - agent: dreamer_v2_slotted_debug + - env: dm_cartpole + - training: dm + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +logger: + type: tensorboard + message: Cartpole with slot attention, 1e3 kl, 2 iter num, free nats + log_grads: false + +training: + checkpoint_path: null + steps: 1e6 + val_logs_every: 1e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + n_jobs: 1 + #sweeper: + # params: + # agent.world_model.kl_loss_scale: 1e1,1e2,1e3,1e4 + # agent.world_model.vit_l2_ratio: 0.1,0.9 + + diff --git a/rl_sandbox/config/config_slotted_debug.yaml b/rl_sandbox/config/config_slotted_debug.yaml new file mode 100644 index 0000000..6388941 --- /dev/null +++ b/rl_sandbox/config/config_slotted_debug.yaml @@ -0,0 +1,52 @@ +defaults: + - agent: dreamer_v2_slotted_combined + - env: crafter + - logger: tensorboard + - training: crafter + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +agent: + world_model: + encode_vit: true + vit_l2_ratio: 1.0 + kl_loss_scale: 1e4 + +logger: + message: Combined encode vit + log_grads: false + +training: + 
checkpoint_path: null + steps: 1e6 + val_logs_every: 2e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + n_jobs: 1 + #sweeper: + # params: + # agent.world_model.slots_num: 3,6 + # agent.world_model.per_slot_rec_loss: true + # agent.world_model.mask_combination: soft,hard + # agent.world_model.kl_loss_scale: 1e2 + # agent.world_model.vit_l2_ratio: 0.1,1e-3 + diff --git a/rl_sandbox/config/env/atari.yaml b/rl_sandbox/config/env/atari.yaml new file mode 100644 index 0000000..e4ac2d4 --- /dev/null +++ b/rl_sandbox/config/env/atari.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.AtariEnv +task_name: daemon_attack +sticky: true +obs_res: [64, 64] +repeat_action_num: 1 +life_done: false +greyscale: false +transforms: [] diff --git a/rl_sandbox/config/env/atari_amidar.yaml b/rl_sandbox/config/env/atari_amidar.yaml new file mode 100644 index 0000000..657132c --- /dev/null +++ b/rl_sandbox/config/env/atari_amidar.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: amidar diff --git a/rl_sandbox/config/env/atari_asterix.yaml b/rl_sandbox/config/env/atari_asterix.yaml new file mode 100644 index 0000000..8618320 --- /dev/null +++ b/rl_sandbox/config/env/atari_asterix.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: asterix diff --git a/rl_sandbox/config/env/atari_chopper_command.yaml b/rl_sandbox/config/env/atari_chopper_command.yaml new file mode 100644 index 0000000..12ced33 --- /dev/null +++ b/rl_sandbox/config/env/atari_chopper_command.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: chopper_command diff --git a/rl_sandbox/config/env/atari_demon_attack.yaml b/rl_sandbox/config/env/atari_demon_attack.yaml new file mode 100644 index 0000000..3239984 --- /dev/null +++ b/rl_sandbox/config/env/atari_demon_attack.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: demon_attack diff --git a/rl_sandbox/config/env/atari_freeway.yaml b/rl_sandbox/config/env/atari_freeway.yaml new file mode 100644 index 0000000..9e1555c --- /dev/null +++ b/rl_sandbox/config/env/atari_freeway.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: freeway diff --git a/rl_sandbox/config/env/atari_private_eye.yaml b/rl_sandbox/config/env/atari_private_eye.yaml new file mode 100644 index 0000000..67d16a6 --- /dev/null +++ b/rl_sandbox/config/env/atari_private_eye.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: private_eye diff --git a/rl_sandbox/config/env/atari_venture.yaml b/rl_sandbox/config/env/atari_venture.yaml new file mode 100644 index 0000000..f39acc3 --- /dev/null +++ b/rl_sandbox/config/env/atari_venture.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: venture diff --git a/rl_sandbox/config/env/atari_video_pinball.yaml b/rl_sandbox/config/env/atari_video_pinball.yaml new file mode 100644 index 0000000..7e2b8dc --- /dev/null +++ b/rl_sandbox/config/env/atari_video_pinball.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: video_pinball diff --git a/rl_sandbox/config/env/crafter.yaml b/rl_sandbox/config/env/crafter.yaml new file mode 100644 index 0000000..822ac27 --- /dev/null +++ b/rl_sandbox/config/env/crafter.yaml @@ -0,0 +1,6 @@ +_target_: rl_sandbox.utils.env.GymEnv +task_name: CrafterReward-v1 +run_on_pixels: false # it is run on pixels by default +obs_res: [64, 64] +repeat_action_num: 1 +transforms: [] diff --git 
a/rl_sandbox/config/env/dm_acrobot.yaml b/rl_sandbox/config/env/dm_acrobot.yaml new file mode 100644 index 0000000..313a1f6 --- /dev/null +++ b/rl_sandbox/config/env/dm_acrobot.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: acrobot +task_name: swingup +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_cartpole.yaml b/rl_sandbox/config/env/dm_cartpole.yaml new file mode 100644 index 0000000..5cb6345 --- /dev/null +++ b/rl_sandbox/config/env/dm_cartpole.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: cartpole +task_name: swingup +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_cheetah.yaml b/rl_sandbox/config/env/dm_cheetah.yaml new file mode 100644 index 0000000..a6d1490 --- /dev/null +++ b/rl_sandbox/config/env/dm_cheetah.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: cheetah +task_name: run +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_finger_spin.yaml b/rl_sandbox/config/env/dm_finger_spin.yaml new file mode 100644 index 0000000..a4b8f9f --- /dev/null +++ b/rl_sandbox/config/env/dm_finger_spin.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: finger +task_name: spin +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_finger_turn_easy.yaml b/rl_sandbox/config/env/dm_finger_turn_easy.yaml new file mode 100644 index 0000000..bbc6de7 --- /dev/null +++ b/rl_sandbox/config/env/dm_finger_turn_easy.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: finger +task_name: turn_easy +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_finger_turn_hard.yaml b/rl_sandbox/config/env/dm_finger_turn_hard.yaml new file mode 100644 index 0000000..b040df1 --- /dev/null +++ b/rl_sandbox/config/env/dm_finger_turn_hard.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: finger +task_name: turn_hard +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_hopper_hop.yaml b/rl_sandbox/config/env/dm_hopper_hop.yaml new file mode 100644 index 0000000..ff8998f --- /dev/null +++ b/rl_sandbox/config/env/dm_hopper_hop.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: hopper +task_name: hop +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_manipulator.yaml b/rl_sandbox/config/env/dm_manipulator.yaml new file mode 100644 index 0000000..6cbeea6 --- /dev/null +++ b/rl_sandbox/config/env/dm_manipulator.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: manipulator +task_name: bring_ball +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_quadruped.yaml b/rl_sandbox/config/env/dm_quadruped.yaml new file mode 100644 index 0000000..9f73398 --- /dev/null +++ b/rl_sandbox/config/env/dm_quadruped.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: quadruped +task_name: run +run_on_pixels: true +obs_res: [64, 64] +camera_id: 2 +repeat_action_num: 2 +transforms: [] diff --git 
a/rl_sandbox/config/env/dm_quadruped_walk.yaml b/rl_sandbox/config/env/dm_quadruped_walk.yaml new file mode 100644 index 0000000..aa5e541 --- /dev/null +++ b/rl_sandbox/config/env/dm_quadruped_walk.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: quadruped +task_name: walk +run_on_pixels: true +obs_res: [64, 64] +camera_id: 2 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_reacher_hard.yaml b/rl_sandbox/config/env/dm_reacher_hard.yaml new file mode 100644 index 0000000..6ecbd96 --- /dev/null +++ b/rl_sandbox/config/env/dm_reacher_hard.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: reacher +task_name: hard +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_walker.yaml b/rl_sandbox/config/env/dm_walker.yaml new file mode 100644 index 0000000..c97057c --- /dev/null +++ b/rl_sandbox/config/env/dm_walker.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: walker +task_name: run +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_walker_stand.yaml b/rl_sandbox/config/env/dm_walker_stand.yaml new file mode 100644 index 0000000..ff2f83a --- /dev/null +++ b/rl_sandbox/config/env/dm_walker_stand.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: walker +task_name: stand +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/mock.yaml b/rl_sandbox/config/env/mock.yaml new file mode 100644 index 0000000..c62296c --- /dev/null +++ b/rl_sandbox/config/env/mock.yaml @@ -0,0 +1,5 @@ +_target_: rl_sandbox.utils.env.MockEnv +run_on_pixels: true +obs_res: [64, 64] +repeat_action_num: 5 +transforms: [] diff --git a/rl_sandbox/config/logger/tensorboard.yaml b/rl_sandbox/config/logger/tensorboard.yaml new file mode 100644 index 0000000..8540962 --- /dev/null +++ b/rl_sandbox/config/logger/tensorboard.yaml @@ -0,0 +1 @@ +type: tensorboard diff --git a/rl_sandbox/config/logger/wandb.yaml b/rl_sandbox/config/logger/wandb.yaml new file mode 100644 index 0000000..d05c8be --- /dev/null +++ b/rl_sandbox/config/logger/wandb.yaml @@ -0,0 +1,2 @@ +type: wandb +project: slotted_dreamer diff --git a/rl_sandbox/config/training/atari.yaml b/rl_sandbox/config/training/atari.yaml new file mode 100644 index 0000000..1aa0a6f --- /dev/null +++ b/rl_sandbox/config/training/atari.yaml @@ -0,0 +1,9 @@ +steps: 4e4 +prefill: 50000 +batch_size: 16 +f16_precision: false +pretrain: 1 +prioritize_ends: true +train_every: 16 +save_checkpoint_every: 5e5 +val_logs_every: 2e4 diff --git a/rl_sandbox/config/training/crafter.yaml b/rl_sandbox/config/training/crafter.yaml new file mode 100644 index 0000000..ba1943f --- /dev/null +++ b/rl_sandbox/config/training/crafter.yaml @@ -0,0 +1,9 @@ +steps: 1e6 +prefill: 10000 +batch_size: 16 +f16_precision: false +pretrain: 1 +prioritize_ends: true +train_every: 5 +save_checkpoint_every: 5e5 +val_logs_every: 2e4 diff --git a/rl_sandbox/config/training/dm.yaml b/rl_sandbox/config/training/dm.yaml new file mode 100644 index 0000000..67fd27f --- /dev/null +++ b/rl_sandbox/config/training/dm.yaml @@ -0,0 +1,9 @@ +steps: 1e6 +prefill: 1000 +batch_size: 50 +pretrain: 100 +prioritize_ends: false +train_every: 4 +save_checkpoint_every: 2e6 +val_logs_every: 2e4 +f16_precision: false diff --git a/rl_sandbox/crafter_metrics.py b/rl_sandbox/crafter_metrics.py new file mode 
100644 index 0000000..a1ba450 --- /dev/null +++ b/rl_sandbox/crafter_metrics.py @@ -0,0 +1,78 @@ +import json +import pathlib +import warnings +import collections +from pathlib import Path + +import numpy as np + +from rl_sandbox.utils.replay_buffer import Rollout + +def compute_scores(percents): + # Geometric mean with an offset of 1%. + assert (0 <= percents).all() and (percents <= 100).all() + if (percents <= 1.0).all(): + print('Warning: The input may not be in the right range.') + with warnings.catch_warnings(): # Empty seeds become NaN. + warnings.simplefilter('ignore', category=RuntimeWarning) + scores = np.exp(np.nanmean(np.log(1 + percents), -1)) - 1 + return scores + + +def load_stats(filename, budget): + steps = 0 + rewards = [] + lengths = [] + achievements = collections.defaultdict(list) + for line in filename.read_text().split('\n'): + if not line.strip(): + continue + episode = json.loads(line) + steps += episode['length'] + if steps > budget: + break + lengths.append(episode['length']) + for key, value in episode.items(): + if key.startswith('achievement_'): + achievements[key].append(value) + unlocks = int(np.sum([(v[-1] >= 1) for v in achievements.values()])) + health = -0.9 + rewards.append(unlocks + health) + return rewards, lengths, achievements + + +class CrafterMetricsEvaluator(): + def __init__(self, agent: 'DreamerV2'): + self.agent = agent + self.episode = 0 + + def on_val(self, logger, rollouts: list[Rollout], global_step: int): + if logger.log_dir() is None: + return + budget = 1e6 + stats_file = Path(logger.log_dir()) / "stats.jsonl" + _, lengths, achievements = load_stats(stats_file, budget) + + tasks = list(achievements.keys()) + + xs = np.cumsum(lengths).tolist() + episodes = (np.array(xs) <= budget).sum() + percents = np.empty((len(achievements))) + percents[:] = np.nan + for key, values in achievements.items(): + k = tasks.index(key) + percent = 100 * (np.array(values[:episodes]) >= 1).mean() + percents[k] = percent + + score = compute_scores(percents) + + logger.log({"score": score}, global_step, mode='val') + + def on_step(self, logger): + pass + + def on_episode(self, logger, rollout, global_step: int): + pass + + + diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py new file mode 100644 index 0000000..ce7f5ac --- /dev/null +++ b/rl_sandbox/metrics.py @@ -0,0 +1,518 @@ +import numpy as np +import matplotlib.pyplot as plt +import torchvision as tv +from torch.nn import functional as F +import torch + +from rl_sandbox.utils.replay_buffer import Rollout +from rl_sandbox.utils.replay_buffer import (Action, Actions, Observation, + Observations, Rewards, + TerminationFlags, IsFirstFlags) + + +class EpisodeMetricsEvaluator(): + def __init__(self, agent: 'DreamerV2', log_video: bool = False): + self.agent = agent + self.episode = 0 + self.log_video = log_video + + def on_step(self, logger): + pass + + def on_episode(self, logger, rollout, global_step: int): + self.episode += 1 + + metrics = self.calculate_metrics([rollout]) + logger.log(metrics, global_step, mode='train') + + def on_val(self, logger, rollouts: list[Rollout], global_step: int): + metrics = self.calculate_metrics(rollouts) + logger.log(metrics, global_step, mode='val') + if self.log_video: + video = rollouts[0].obs.unsqueeze(0) + logger.add_video('val/visualization', self.agent.unprocess_obs(video), global_step) + + def calculate_metrics(self, rollouts: list[Rollout]): + return { + 'episode_len': self._episode_duration(rollouts), + 'episode_return': self._episode_return(rollouts) + } + + 
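# compute_scores above follows the Crafter benchmark's scoring rule: the score is the
# geometric mean of the per-achievement success rates (in percent) with a 1% offset, so
# rarely-unlocked achievements pull the score down without zeroing it out entirely.
# A small self-contained sketch with made-up success rates (numbers are illustrative only):
import numpy as np

def crafter_score(percents: np.ndarray) -> float:
    # Geometric mean with an offset of 1%, as in compute_scores above; percents lie in [0, 100].
    return float(np.exp(np.nanmean(np.log(1 + percents), -1)) - 1)

success_rates = np.array([90.0, 25.0, 0.5])  # hypothetical per-achievement percentages
print(crafter_score(success_rates))          # ~14.3, far below the arithmetic mean of 38.5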
def _episode_duration(self, rollouts: list[Rollout]): + return np.mean(list(map(lambda x: len(x.obs), rollouts))) + + def _episode_return(self, rollouts: list[Rollout]): + return np.mean(list(map(lambda x: sum(x.rewards), rollouts))) + +class DreamerMetricsEvaluator(): + def __init__(self, agent: 'DreamerV2'): + self.agent = agent + self.stored_steps = 0 + self.episode = 0 + + if agent.is_discrete: + pass + + self.reset_ep() + + def reset_ep(self): + self._latent_probs = torch.zeros((self.agent.world_model.latent_classes, self.agent.world_model.latent_dim), device=self.agent.device) + self._action_probs = torch.zeros((self.agent.actions_num), device=self.agent.device) + self.stored_steps = 0 + + def on_step(self, logger): + self.stored_steps += 1 + + if self.agent.is_discrete: + self._action_probs += self._action_probs + self._latent_probs += self.agent._state.stoch_dist.base_dist.probs.squeeze(0).mean(dim=0) + + def on_episode(self, logger, rollout, global_step: int): + latent_hist = (self._latent_probs / self.stored_steps).detach().cpu().numpy() + self.latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) + self.action_hist = (self.agent._action_probs / self.stored_steps).detach().cpu().numpy() + + self.reset_ep() + self.episode += 1 + + def on_val(self, logger, rollouts: list[Rollout], global_step: int): + self.viz_log(rollouts[0], logger, global_step) + + if self.episode == 0: + return + + # if discrete action space + if self.agent.is_discrete: + fig = plt.Figure() + ax = fig.add_axes([0, 0, 1, 1]) + ax.bar(np.arange(self.agent.actions_num), self.action_hist) + logger.add_figure('val/action_probs', fig, self.episode) + else: + # log mean +- std + pass + logger.add_image('val/latent_probs', np.expand_dims(self.latent_hist, 0), global_step, dataformats='HW') + logger.add_image('val/latent_probs_sorted', np.expand_dims(np.sort(self.latent_hist, axis=1), 0), global_step, dataformats='HW') + + def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): + # obs = self.agent.preprocess_obs(obs) + if self.agent.is_discrete: + actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() + video = [] + rews = [] + + state = None + for idx, (o, a) in enumerate(list(zip(obs, actions))): + if idx > update_num: + break + state = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), state) + video_r = self.agent.world_model.image_predictor(state.combined).mode + rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) + + video.append(self.agent.unprocess_obs(video_r)) + + rews = torch.Tensor(rews).to(obs.device) + + if update_num < len(obs): + states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) + rews = torch.cat([rews, rews_2[1:].squeeze()]) + video_r = self.agent.world_model.image_predictor(states.combined[1:]).mode.detach() + + video.append(self.agent.unprocess_obs(video_r)) + + return torch.cat(video), rews + + def viz_log(self, rollout, logger, epoch_num): + rollout = rollout.to(device=self.agent.device) + init_indeces = np.random.choice(len(rollout.obs) - self.agent.imagination_horizon, 5) + + videos = torch.cat([ + rollout.obs[init_idx:init_idx + self.agent.imagination_horizon] for init_idx in init_indeces + ], dim=3) + videos = self.agent.unprocess_obs(videos) + + real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] + + videos_r, 
imagined_rewards = zip(*[self._generate_video(obs_0, a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( + [rollout.obs[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], + [rollout.actions[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces]) + ]) + videos_r = torch.cat(videos_r, dim=3) + + videos_comparison = torch.cat([videos, videos_r, (torch.abs(videos.float() - videos_r.float() + 1)/2).to(dtype=torch.uint8)], dim=2).unsqueeze(0) + + logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) + + rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() + logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) + + logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) + +class SlottedDreamerMetricsEvaluator(DreamerMetricsEvaluator): + def on_step(self, logger): + self.stored_steps += 1 + + if self.agent.is_discrete: + self._action_probs += self._action_probs + self._latent_probs += self.agent._state[0].stoch_dist.base_dist.probs.squeeze().mean(dim=0) + + def on_episode(self, logger, rollout, global_step: int): + wm = self.agent.world_model + + mu = wm.slot_attention.slots_mu + sigma = wm.slot_attention.slots_logsigma.exp() + self.mu_hist = torch.mean((mu - mu.squeeze(0).unsqueeze(1)) ** 2, dim=-1) + self.sigma_hist = torch.mean((sigma - sigma.squeeze(0).unsqueeze(1)) ** 2, dim=-1) + + + super().on_episode(logger, rollout, global_step) + + def on_val(self, logger, rollouts: list[Rollout], global_step: int): + super().on_val(logger, rollouts, global_step) + + if self.episode == 0: + return + + wm = self.agent.world_model + + if hasattr(wm.recurrent_model, 'last_attention'): + logger.add_image('val/mixer_attention', wm.recurrent_model.last_attention.unsqueeze(0), global_step, dataformats='HW') + # logger.add_image('val/embed_attention', wm.recurrent_model.embed_attn.unsqueeze(0), global_step, dataformats='HW') + + logger.add_image('val/slot_attention_mu', (self.mu_hist/self.mu_hist.max()).unsqueeze(0), global_step, dataformats='HW') + logger.add_image('val/slot_attention_sigma', (self.sigma_hist/self.sigma_hist.max()).unsqueeze(0), global_step, dataformats='HW') + + logger.add_scalar('val/slot_attention_mu_diff_max', self.mu_hist.max(), global_step) + logger.add_scalar('val/slot_attention_sigma_diff_max', self.sigma_hist.max(), global_step) + + def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): + # obs = torch.from_numpy(obs.copy()).to(self.agent.device) + # obs = self.agent.preprocess_obs(obs) + # actions = self.agent.from_np(actions) + if self.agent.is_discrete: + actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() + video = [] + slots_video = [] + rews = [] + + state = None + prev_slots = None + for idx, (o, a) in enumerate(list(zip(obs, actions))): + if idx > update_num: + break + state, prev_slots = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), (state, prev_slots)) + # video_r = self.agent.world_model.image_predictor(state.combined_slots).mode + + decoded_imgs, masks = self.agent.world_model.image_predictor(state.combined_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) + # TODO: try the scaling of softmax as in attention + img_mask = self.agent.world_model.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1) + + 
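# Illustrative sketch of the slot-compositing step above: the image predictor emits four
# channels per slot (RGB plus a mask logit), slot_mask is assumed to normalize the mask
# logits across slots (a softmax under the `mask_combination: soft` setting), and the
# final frame is the mask-weighted sum over slots. The helper name and shapes below are
# assumptions for illustration, not part of this codebase.
import torch
from torch.nn import functional as F

def composite_slots(decoded: torch.Tensor) -> torch.Tensor:
    # decoded: (B, num_slots, 4, H, W) -> composited RGB frame of shape (B, 3, H, W)
    rgb, mask_logits = decoded.split([3, 1], dim=2)   # per-slot RGB and mask logit
    mask = F.softmax(mask_logits, dim=1)              # normalize masks across the slot axis
    return (rgb * mask).sum(dim=1)                    # mask-weighted sum over slots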
rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) + video.append(self.agent.unprocess_obs(video_r)) + slots_video.append(self.agent.unprocess_obs(decoded_imgs)) + + rews = torch.Tensor(rews).to(obs.device) + + if update_num < len(obs): + states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) + rews = torch.cat([rews, rews_2[1:].squeeze()]) + + # video_r = self.agent.world_model.image_predictor(states.combined_slots[1:]).mode + decoded_imgs, masks = self.agent.world_model.image_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) + img_mask = self.agent.world_model.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1) + + video.append(self.agent.unprocess_obs(video_r)) + slots_video.append(self.agent.unprocess_obs(decoded_imgs)) + + return torch.cat(video), rews, torch.cat(slots_video) + + def viz_log(self, rollout, logger, epoch_num): + rollout = rollout.to(device=self.agent.device) + init_indeces = np.random.choice(len(rollout.obs) - self.agent.imagination_horizon, 5) + + videos = torch.cat([ + rollout.obs[init_idx:init_idx + self.agent.imagination_horizon] for init_idx in init_indeces + ], dim=3) + videos = self.agent.unprocess_obs(videos) + + real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] + + videos_r, imagined_rewards, slots_video = zip(*[self._generate_video(obs_0, a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( + [rollout.obs[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], + [rollout.actions[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces]) + ]) + videos_r = torch.cat(videos_r, dim=3) + + videos_comparison = torch.cat([videos, videos_r, (torch.abs(videos.float() - videos_r.float() + 1)/2).to(dtype=torch.uint8)], dim=2).unsqueeze(0) + slots_video = torch.cat(list(slots_video)[:3], dim=3) + + slots_video = slots_video.permute((0, 2, 3, 1, 4)) + slots_video = slots_video.reshape(*slots_video.shape[:-2], -1).unsqueeze(0) + + logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) + logger.add_video('val/dreamed_slots', slots_video, epoch_num) + + rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() + logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) + + logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) + +class SlottedDinoDreamerMetricsEvaluator(SlottedDreamerMetricsEvaluator): + def _generate_video(self, obs: list[Observation], actions: list[Action], d_feats: list[torch.Tensor], update_num: int): + # obs = torch.from_numpy(obs.copy()).to(self.agent.device) + # obs = self.agent.preprocess_obs(obs) + # actions = self.agent.from_np(actions) + if self.agent.is_discrete: + actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() + video = [] + slots_video = [] + vit_slots_video = [] + vit_mean_err_video = [] + vit_max_err_video = [] + rews = [] + + vit_size = self.agent.world_model.vit_size + + state = None + prev_slots = None + for idx, (o, a, d_feat) in enumerate(list(zip(obs, actions, d_feats))): + if idx > update_num: + break + state, prev_slots = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), (state, prev_slots)) + # video_r = 
self.agent.world_model.image_predictor(state.combined_slots).mode + + decoded_imgs, masks = self.agent.world_model.image_predictor(state.combined_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) + # TODO: try the scaling of softmax as in attention + img_mask = self.agent.world_model.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1) + + decoded_dino_feats, vit_masks = self.agent.world_model.dino_predictor(state.combined_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, vit_size, vit_size).split([self.agent.world_model.vit_feat_dim, 1], dim=2) + vit_mask = F.softmax(vit_masks, dim=1) + decoded_dino = (decoded_dino_feats * vit_mask).sum(dim=1) + upscale = tv.transforms.Resize(64, antialias=True) + + upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) + per_slot_vit = (upscaled_mask.unsqueeze(1) * o.to(self.agent.device).unsqueeze(0)).unsqueeze(0) + + rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) + video.append(self.agent.unprocess_obs(video_r)) + slots_video.append(self.agent.unprocess_obs(decoded_imgs)) + vit_slots_video.append(self.agent.unprocess_obs(per_slot_vit/upscaled_mask.max())) + vit_mean_err_video.append(((d_feat.reshape(decoded_dino.shape) - decoded_dino)**2).mean(dim=1)) + vit_max_err_video.append(((d_feat.reshape(decoded_dino.shape) - decoded_dino)**2).max(dim=1).values) + + rews = torch.Tensor(rews).to(obs.device) + + if update_num < len(obs): + states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) + rews = torch.cat([rews, rews_2[1:].squeeze()]) + + decoded_imgs, masks = self.agent.world_model.image_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) + img_mask = self.agent.world_model.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1) + + decoded_dino_feats, vit_masks = self.agent.world_model.dino_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, vit_size, vit_size).split([self.agent.world_model.vit_feat_dim, 1], dim=2) + vit_mask = F.softmax(vit_masks, dim=1) + decoded_dino = (decoded_dino_feats * vit_mask).sum(dim=1) + + upscale = tv.transforms.Resize(64, antialias=True) + upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) + per_slot_vit = (upscaled_mask.unsqueeze(2) * obs[update_num+1:].to(self.agent.device).unsqueeze(1)) + + video.append(self.agent.unprocess_obs(video_r)) + slots_video.append(self.agent.unprocess_obs(decoded_imgs)) + vit_slots_video.append(self.agent.unprocess_obs(per_slot_vit/torch.amax(upscaled_mask, dim=(1,2,3)).view(-1, 1, 1, 1, 1))) + vit_mean_err_video.append(((d_feats[update_num+1:].reshape(decoded_dino.shape) - decoded_dino)**2).mean(dim=1)) + vit_max_err_video.append(((d_feats[update_num+1:].reshape(decoded_dino.shape) - decoded_dino)**2).max(dim=1).values) + + return torch.cat(video), rews, torch.cat(slots_video), torch.cat(vit_slots_video), torch.cat(vit_mean_err_video).unsqueeze(0), torch.cat(vit_max_err_video).unsqueeze(0) + + def viz_log(self, rollout, logger, epoch_num): + rollout = rollout.to(device=self.agent.device) + init_indeces = np.random.choice(len(rollout.obs) - self.agent.imagination_horizon, 5) + + videos = torch.cat([ + 
rollout.obs[init_idx:init_idx + self.agent.imagination_horizon] for init_idx in init_indeces + ], dim=3) + videos = self.agent.unprocess_obs(videos) + + real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] + + videos_r, imagined_rewards, slots_video, vit_masks_video, vit_mean_err_video, vit_max_err_video = zip(*[self._generate_video(obs_0, a_0, d_feat_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0, d_feat_0 in zip( + [rollout.obs[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], + [rollout.actions[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], + [rollout.additional_data['d_features'][idx:idx+ self.agent.imagination_horizon] for idx in init_indeces]) + ]) + videos_r = torch.cat(videos_r, dim=3) + + vit_mean_err_video = torch.cat(vit_mean_err_video, dim=3) + vit_max_err_video = torch.cat(vit_max_err_video, dim=3) + vit_mean_err_video = (vit_mean_err_video/vit_mean_err_video.max() * 255.0).to(dtype=torch.uint8) + vit_max_err_video = (vit_max_err_video/vit_max_err_video.max() * 255.0).to(dtype=torch.uint8) + + videos_comparison = torch.cat([videos, videos_r, (torch.abs(videos.float() - videos_r.float() + 1)/2).to(dtype=torch.uint8)], dim=2).unsqueeze(0) + + slots_video = torch.cat(list(slots_video)[:3], dim=3) + slots_video = slots_video.permute((0, 2, 3, 1, 4)) + slots_video = slots_video.reshape(*slots_video.shape[:-2], -1).unsqueeze(0) + + vit_masks_video = torch.cat(list(vit_masks_video)[:3], dim=3) + vit_masks_video = vit_masks_video.permute((0, 2, 3, 1, 4)) + vit_masks_video = vit_masks_video.reshape(*vit_masks_video.shape[:-2], -1).unsqueeze(0) + + logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) + logger.add_video('val/dreamed_slots', slots_video, epoch_num) + logger.add_video('val/dreamed_vit_masks', vit_masks_video, epoch_num) + logger.add_video('val/vit_mean_err', vit_mean_err_video.detach().cpu().unsqueeze(2).repeat(1, 1, 3, 1, 1), epoch_num) + logger.add_video('val/vit_max_err', vit_max_err_video.detach().cpu().unsqueeze(2).repeat(1, 1, 3, 1, 1), epoch_num) + + # FIXME: rewrite sum(...) 
as (...).sum() + rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() + logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) + + logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) + +class PostSlottedDreamerMetricsEvaluator(SlottedDreamerMetricsEvaluator): + def on_step(self, logger): + self.stored_steps += 1 + + if self.agent.is_discrete: + self._action_probs += self._action_probs + self._latent_probs += self.agent._state[0].stoch_dist.base_dist.probs.squeeze() + + def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): + # obs = torch.from_numpy(obs.copy()).to(self.agent.device) + # obs = self.agent.preprocess_obs(obs) + # actions = self.agent.from_np(actions) + if self.agent.is_discrete: + actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() + video = [] + slots_video = [] + rews = [] + + state = None + prev_slots = None + for idx, (o, a) in enumerate(list(zip(obs, actions))): + if idx > update_num: + break + state, prev_slots = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), (state, prev_slots)) + # video_r = self.agent.world_model.image_predictor(state.combined_slots).mode + + wm_state = self.agent.world_model.state_reshuffle(state.combined) + wm_state = wm_state.reshape(*wm_state.shape[:-1], self.agent.world_model.state_feature_num, self.agent.world_model.n_dim) + wm_state_pos_embedded = self.agent.world_model.positional_augmenter_inp(wm_state.unsqueeze(-3)).squeeze(-3) + wm_state_slots = self.agent.world_model.slot_attention(wm_state_pos_embedded.flatten(0, 1), None) + + decoded_imgs, masks = self.agent.world_model.image_predictor(wm_state_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) + # TODO: try the scaling of softmax as in attention + img_mask = self.agent.world_model.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1) + + rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) + video.append(self.agent.unprocess_obs(video_r)) + slots_video.append(self.agent.unprocess_obs(decoded_imgs)) + + rews = torch.Tensor(rews).to(obs.device) + + if update_num < len(obs): + states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) + rews = torch.cat([rews, rews_2[1:].squeeze()]) + + wm_state = self.agent.world_model.state_reshuffle(states.combined[1:]) + wm_state = wm_state.reshape(*wm_state.shape[:-1], self.agent.world_model.state_feature_num, self.agent.world_model.n_dim) + wm_state_pos_embedded = self.agent.world_model.positional_augmenter_inp(wm_state.unsqueeze(-3)).squeeze(-3) + wm_state_slots = self.agent.world_model.slot_attention(wm_state_pos_embedded.flatten(0, 1), None) + + # video_r = self.agent.world_model.image_predictor(states.combined_slots[1:]).mode + decoded_imgs, masks = self.agent.world_model.image_predictor(wm_state_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) + img_mask = self.agent.world_model.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1) + + video.append(self.agent.unprocess_obs(video_r)) + slots_video.append(self.agent.unprocess_obs(decoded_imgs)) + + return torch.cat(video), rews, torch.cat(slots_video) + + +class 
PostSlottedDinoDreamerMetricsEvaluator(SlottedDreamerMetricsEvaluator): + def on_step(self, logger): + self.stored_steps += 1 + + if self.agent.is_discrete: + self._action_probs += self._action_probs + self._latent_probs += self.agent._state[0].stoch_dist.base_dist.probs.squeeze() + + def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): + # obs = torch.from_numpy(obs.copy()).to(self.agent.device) + # obs = self.agent.preprocess_obs(obs) + # actions = self.agent.from_np(actions) + if self.agent.is_discrete: + actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() + video = [] + slots_video = [] + rews = [] + + vit_size = self.agent.world_model.vit_size + + state = None + prev_slots = None + for idx, (o, a) in enumerate(list(zip(obs, actions))): + if idx > update_num: + break + state, prev_slots = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), (state, prev_slots)) + # video_r = self.agent.world_model.image_predictor(state.combined_slots).mode + + wm_state = self.agent.world_model.state_reshuffle(state.combined) + wm_state = wm_state.reshape(*wm_state.shape[:-1], self.agent.world_model.state_feature_num, self.agent.world_model.n_dim) + wm_state_pos_embedded = self.agent.world_model.positional_augmenter_inp(wm_state.unsqueeze(-3)).squeeze(-3) + wm_state_slots = self.agent.world_model.slot_attention(wm_state_pos_embedded.flatten(0, 1), None) + + # decoded_imgs, masks = self.agent.world_model.image_predictor(wm_state_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) + # TODO: try the scaling of softmax as in attention + # img_mask = self.agent.world_model.slot_mask(masks) + # decoded_imgs = decoded_imgs * img_mask + # video_r = torch.sum(decoded_imgs, dim=1) + + decoded_dino_feats, vit_masks = self.agent.world_model.dino_predictor(wm_state_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, vit_size, vit_size).split([self.agent.world_model.vit_feat_dim, 1], dim=2) + vit_mask = self.agent.world_model.slot_mask(vit_masks) + decoded_dino_feats = decoded_dino_feats * vit_mask + decoded_dino = (decoded_dino_feats).sum(dim=1) + upscale = tv.transforms.Resize(64, antialias=True) + + upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) + per_slot_vit = (upscaled_mask.unsqueeze(1) * o.to(self.agent.device).unsqueeze(0)).unsqueeze(0) + + rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) + video.append(self.agent.unprocess_obs(o).unsqueeze(0)) + slots_video.append(self.agent.unprocess_obs(per_slot_vit)) + + rews = torch.Tensor(rews).to(obs.device) + + if update_num < len(obs): + states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) + rews = torch.cat([rews, rews_2[1:].squeeze()]) + + wm_state = self.agent.world_model.state_reshuffle(states.combined[1:]) + wm_state = wm_state.reshape(*wm_state.shape[:-1], self.agent.world_model.state_feature_num, self.agent.world_model.n_dim) + wm_state_pos_embedded = self.agent.world_model.positional_augmenter_inp(wm_state.unsqueeze(-3)).squeeze(-3) + wm_state_slots = self.agent.world_model.slot_attention(wm_state_pos_embedded.flatten(0, 1), None) + + decoded_dino_feats, vit_masks = self.agent.world_model.dino_predictor(wm_state_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, vit_size, 
vit_size).split([self.agent.world_model.vit_feat_dim, 1], dim=2) + vit_mask = F.softmax(vit_masks, dim=1) + decoded_dino = (decoded_dino_feats * vit_mask).sum(dim=1) + + upscale = tv.transforms.Resize(64, antialias=True) + upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) + per_slot_vit = (upscaled_mask.unsqueeze(2) * obs[update_num+1:].to(self.agent.device).unsqueeze(1)) + + video.append(self.agent.unprocess_obs(obs[update_num+1:])) + slots_video.append(self.agent.unprocess_obs(per_slot_vit)) + + return torch.cat(video), rews, torch.cat(slots_video) diff --git a/rl_sandbox/test/dreamer/test_critic.py b/rl_sandbox/test/dreamer/test_critic.py new file mode 100644 index 0000000..6d90ea5 --- /dev/null +++ b/rl_sandbox/test/dreamer/test_critic.py @@ -0,0 +1,62 @@ +import pytest +import torch + +from rl_sandbox.agents.dreamer_v2 import ImaginativeCritic + +@pytest.fixture +def imaginative_critic(): + return ImaginativeCritic(discount_factor=1, + update_interval=100, + soft_update_fraction=1, + value_target_lambda=0.95, + latent_dim=10) + +def test_lambda_return_discount_0(imaginative_critic): + # Should just return rewards if discount_factor is 0 + imaginative_critic.lambda_ = 0 + imaginative_critic.gamma = 0 + rs = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + vs = torch.ones_like(rs) + ts = torch.ones_like(rs) + lambda_ret = imaginative_critic._lambda_return(vs, rs, ts) + assert torch.all(lambda_ret == rs) + +def test_lambda_return_lambda_0(imaginative_critic): + # Should return 1-step return if lambda is 0 + imaginative_critic.lambda_ = 0 + imaginative_critic.gamma = 1 + vs = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + rs = torch.ones_like(vs) + ts = torch.ones_like(vs) + lambda_ret = imaginative_critic._lambda_return(vs, rs, ts) + assert torch.all(lambda_ret == torch.Tensor([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])) + +def test_lambda_return_lambda_0_gamma_0_5(imaginative_critic): + # Should return 1-step return if lambda is 0 + imaginative_critic.lambda_ = 0 + imaginative_critic.gamma = 0.5 + vs = torch.Tensor([2, 2, 4, 4, 6, 6, 8, 8, 10, 10]) + rs = torch.ones_like(vs) + ts = torch.ones_like(vs) + lambda_ret = imaginative_critic._lambda_return(vs, rs, ts) + assert torch.all(lambda_ret == torch.Tensor([2, 2, 3, 3, 4, 4, 5, 5, 6, 6])) + +def test_lambda_return_lambda_1(imaginative_critic): + # Should return Monte-Carlo return + imaginative_critic.lambda_ = 1 + imaginative_critic.gamma = 1 + vs = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + rs = torch.ones_like(vs) + ts = torch.ones_like(vs) + lambda_ret = imaginative_critic._lambda_return(vs, rs, ts) + assert torch.all(lambda_ret == torch.Tensor([20, 19, 18, 17, 16, 15, 14, 13, 12, 11])) + +def test_lambda_return_lambda_1_gamma_0_5(imaginative_critic): + # Should return Monte-Carlo return + imaginative_critic.lambda_ = 1 + imaginative_critic.gamma = 0.5 + vs = torch.Tensor([2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]) + rs = torch.zeros_like(vs) + ts = torch.ones_like(vs) + lambda_ret = imaginative_critic._lambda_return(vs, rs, ts) + assert torch.all(lambda_ret == torch.Tensor([1, 2, 4, 8, 16, 32, 64, 128, 256, 512])) diff --git a/rl_sandbox/test/test_linear_scheduler.py b/rl_sandbox/test/test_linear_scheduler.py new file mode 100644 index 0000000..e39a66d --- /dev/null +++ b/rl_sandbox/test/test_linear_scheduler.py @@ -0,0 +1,15 @@ +from rl_sandbox.utils.schedulers import LinearScheduler + +def test_linear_schedule(): + s = LinearScheduler(0, 10, 5) + assert s.step() == 0 + assert s.step() == 2.5 + assert s.step() == 5 + 
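+    # LinearScheduler(0, 10, 5) is expected to interpolate linearly from 0 to 10
+    # across 5 calls to step() (increments of 2.5) and then clamp at 10, as the
+    # test below checks.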
assert s.step() == 7.5 + assert s.step() == 10.0 + +def test_linear_schedule_after(): + s = LinearScheduler(0, 10, 5) + for _ in range(5): + s.step() + assert s.step() == 10.0 diff --git a/rl_sandbox/test/test_replay_buffer.py b/rl_sandbox/test/test_replay_buffer.py new file mode 100644 index 0000000..392db93 --- /dev/null +++ b/rl_sandbox/test/test_replay_buffer.py @@ -0,0 +1,87 @@ +import numpy as np +from pytest import fixture + +from rl_sandbox.utils.replay_buffer import ReplayBuffer, Rollout + + +@fixture +def rep_buf(): + return ReplayBuffer() + + +def test_creation(rep_buf: ReplayBuffer): + assert len(rep_buf) == 0 + + +def test_adding(rep_buf: ReplayBuffer): + s = np.ones((3, 8)) + a = np.ones((3, 3), dtype=np.int32) + r = np.ones((3)) + n = np.ones((3, 8)) + f = np.zeros((3), dtype=np.bool8) + rep_buf.add_rollout(Rollout(s, a, r, n, f)) + + assert len(rep_buf) == 3 + + s = np.zeros((3, 8)) + a = np.zeros((3, 3), dtype=np.int32) + r = np.zeros((3)) + n = np.zeros((3, 8)) + f = np.zeros((3), dtype=np.bool8) + rep_buf.add_rollout(Rollout(s, a, r, n, f)) + + assert len(rep_buf) == 6 + + +def test_can_sample(rep_buf: ReplayBuffer): + assert rep_buf.can_sample(1) == False + + s = np.ones((3, 8)) + a = np.zeros((3, 3), dtype=np.int32) + r = np.ones((3)) + n = np.zeros((3, 8)) + f = np.zeros((3), dtype=np.bool8) + rep_buf.add_rollout(Rollout(s, a, r, n, f)) + + assert rep_buf.can_sample(5) == False + assert rep_buf.can_sample(1) == True + + rep_buf.add_rollout(Rollout(s, a, r, n, f)) + + assert rep_buf.can_sample(5) == True + + +def test_sampling(rep_buf: ReplayBuffer): + for i in range(5): + rep_buf.add_rollout( + Rollout(np.ones((1, 3)), np.ones((1, 2), dtype=np.int32), i * np.ones((1)), + np.ones((3, 8)), np.zeros((3), dtype=np.bool8))) + + np.random.seed(42) + _, _, r, _, _ = rep_buf.sample(3) + assert (r == [1, 4, 3]).all() + + +def test_cluster_sampling(rep_buf: ReplayBuffer): + for i in range(5): + rep_buf.add_rollout( + Rollout(np.stack([np.arange(3, dtype=np.float32) for _ in range(3)]).T, + np.ones((3, 2), dtype=np.int32), i * np.ones((3)), + np.stack([np.arange(1, 4, dtype=np.float32) for _ in range(3)]).T, + np.zeros((3), dtype=np.bool8))) + + np.random.seed(42) + s, _, r, n, _ = rep_buf.sample(4, cluster_size=2) + assert (r == [1, 1, 4, 4]).all() + assert (s[:, 0] == [0, 1, 1, 2]).all() + assert (n[:, 0] == [1, 2, 2, 3]).all() + + s, _, r, n, _ = rep_buf.sample(4, cluster_size=2) + assert (r == [2, 2, 0, 0]).all() + assert (s[:, 0] == [0, 1, 0, 1]).all() + assert (n[:, 0] == [1, 2, 1, 2]).all() + + s, _, r, n, _ = rep_buf.sample(4, cluster_size=2) + assert (r == [0, 0, 4, 4]).all() + assert (s[:, 0] == [1, 2, 1, 2]).all() + assert (n[:, 0] == [2, 3, 2, 3]).all() diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py new file mode 100644 index 0000000..caca930 --- /dev/null +++ b/rl_sandbox/train.py @@ -0,0 +1,163 @@ +import random +import os +os.environ['MUJOCO_GL'] = 'egl' +os.environ["WANDB_MODE"]="offline" + +import crafter +import hydra +import lovely_tensors as lt +import numpy as np +import torch +from gym.spaces import Discrete +from omegaconf import DictConfig +from hydra.core.hydra_config import HydraConfig +from hydra.types import RunMode +from torch.profiler import ProfilerActivity, profile +from tqdm import tqdm + +from rl_sandbox.utils.env import Env +from rl_sandbox.utils.logger import Logger +from rl_sandbox.utils.replay_buffer import ReplayBuffer +from rl_sandbox.utils.rollout_generation import (collect_rollout_num, + fillup_replay_buffer, + iter_rollout) + + 
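+# val_logs() collects `rollout_num` validation rollouts and forwards them to every
+# configured metric. main() below seeds everything, instantiates the env and agent from
+# the Hydra config, prefills the replay buffer, optionally pretrains (or loads a
+# checkpoint), and then alternates environment steps with gradient updates every
+# `training.train_every` steps, logging validation metrics and saving checkpoints at
+# fixed step intervals.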
+def val_logs(agent, val_cfg: DictConfig, metrics, env: Env, logger: Logger, global_step: int): + with torch.no_grad(): + rollouts = collect_rollout_num(env, val_cfg.rollout_num, agent, collect_obs=True) + rollouts = [agent.preprocess(r) for r in rollouts] + + for metric in metrics: + metric.on_val(logger, rollouts, global_step) + + +@hydra.main(version_base="1.2", config_path='config', config_name='config') +def main(cfg: DictConfig): + lt.monkey_patch() + torch.distributions.Distribution.set_default_validate_args(False) + eval('setattr(torch.backends.cudnn, "benchmark", True)') # need to be pickable for multirun + torch.backends.cuda.matmul.allow_tf32 = True + + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(cfg.seed) + random.seed(cfg.seed) + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + + if HydraConfig.get()['mode'] == RunMode.MULTIRUN and cfg.device_type == 'cuda': + num_gpus = torch.cuda.device_count() + gpu_id = HydraConfig.get().job.num % num_gpus + cfg.device_type = f'cuda:{gpu_id}' + cfg.logger.message += "," + ",".join(HydraConfig.get()['overrides']['task']) + + # TODO: Implement smarter techniques for exploration + # (Plan2Explore, etc) + print(f'Start run: {cfg.logger.message}') + logger = Logger(**cfg.logger, cfg=cfg) + + env: Env = hydra.utils.instantiate(cfg.env) + val_env: Env = hydra.utils.instantiate(cfg.env) + # TOOD: Create maybe some additional validation env + if cfg.env.task_name.startswith("Crafter"): + env.env = crafter.Recorder(env.env, + logger.log_dir(), + save_stats=True, + save_video=False, + save_episode=False) + + is_discrete = isinstance(env.action_space, Discrete) + agent = hydra.utils.instantiate( + cfg.agent, + obs_space_num=env.observation_space.shape, + actions_num=env.action_space.n if is_discrete else env.action_space.shape[0], + action_type='discrete' if is_discrete else 'continuous', + device_type=cfg.device_type, + f16_precision=cfg.training.f16_precision, + logger=logger) + + buff = ReplayBuffer(max_len=500_000, + prioritize_ends=cfg.training.prioritize_ends, + min_ep_len=cfg.agent.get('batch_cluster_size', 1) * + (cfg.training.prioritize_ends + 1), + preprocess_func=agent.preprocess, + device = cfg.device_type) + + fillup_replay_buffer( + env, buff, + max(cfg.training.prefill, + cfg.training.batch_size * cfg.agent.get('batch_cluster_size', 1)), + agent=agent) + + metrics = [metric(agent) for metric in hydra.utils.instantiate(cfg.validation.metrics)] + + prof = profile( + activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], + on_trace_ready=torch.profiler.tensorboard_trace_handler(logger.log_dir() + '/profiler'), + schedule=torch.profiler.schedule(wait=10, warmup=10, active=5, repeat=5), + with_stack=True) if cfg.debug.profiler else None + + for i in tqdm(range(int(cfg.training.pretrain)), desc='Pretraining'): + if cfg.training.checkpoint_path is not None: + break + rollout_chunks = buff.sample(cfg.training.batch_size, + cluster_size=cfg.agent.get( + 'batch_cluster_size', 1)) + losses = agent.train(rollout_chunks) + logger.log(losses, i, mode='pre_train') + + val_logs(agent, cfg.validation, metrics, val_env, logger, 0) + + if cfg.training.checkpoint_path is not None: + prev_global_step = global_step = agent.load_ckpt(cfg.training.checkpoint_path) + else: + prev_global_step = global_step = 0 + + pbar = tqdm(total=cfg.training.steps, desc='Training') + while global_step < cfg.training.steps: + ### Training and exploration + + for env_step in iter_rollout(env, agent): + buff.add_sample(env_step) + + if global_step % 
cfg.training.train_every == 0: + # NOTE: unintuitive that batch_size is now number of total + # samples, but not amount of sequences for recurrent model + rollout_chunk = buff.sample(cfg.training.batch_size, + cluster_size=cfg.agent.get( + 'batch_cluster_size', 1)) + + losses = agent.train(rollout_chunk) + if cfg.debug.profiler: + prof.step() + if global_step % 1000 == 0: + logger.log(losses, global_step, mode='train') + + for metric in metrics: + metric.on_step(logger) + + global_step += cfg.env.repeat_action_num + pbar.update(cfg.env.repeat_action_num) + + for metric in metrics: + metric.on_episode(logger, buff.rollouts[-1], global_step) + + # FIXME: find more appealing solution + ### Validation + if (global_step % cfg.training.val_logs_every) <= (prev_global_step % + cfg.training.val_logs_every): + val_logs(agent, cfg.validation, metrics, val_env, logger, global_step) + + ### Checkpoint + if (global_step % cfg.training.save_checkpoint_every) < ( + prev_global_step % cfg.training.save_checkpoint_every): + agent.save_ckpt(global_step, losses) + + prev_global_step = global_step + + if cfg.debug.profiler: + prof.stop() + + +if __name__ == "__main__": + main() diff --git a/rl_sandbox/utils/dists.py b/rl_sandbox/utils/dists.py new file mode 100644 index 0000000..759fc3c --- /dev/null +++ b/rl_sandbox/utils/dists.py @@ -0,0 +1,204 @@ +import math +from numbers import Number +import typing as t + +import numpy as np +import torch +import torch.distributions as td +from torch import nn +from torch.distributions import Distribution, constraints +from torch.distributions.utils import broadcast_all +from torch.distributions.utils import _standard_normal + +CONST_SQRT_2 = math.sqrt(2) +CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi) +CONST_INV_SQRT_2 = 1 / math.sqrt(2) +CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI) +CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e) + + +class TruncatedStandardNormal(Distribution): + """ + Truncated Standard Normal distribution + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + """ + + arg_constraints = { + 'a': constraints.real, + 'b': constraints.real, + } + has_rsample = True + + def __init__(self, a, b, validate_args=None): + self.a, self.b = broadcast_all(a, b) + if isinstance(a, Number) and isinstance(b, Number): + batch_shape = torch.Size() + else: + batch_shape = self.a.size() + super(TruncatedStandardNormal, self).__init__(batch_shape, validate_args=validate_args) + if self.a.dtype != self.b.dtype: + raise ValueError('Truncation bounds types are different') + if any((self.a >= self.b).view(-1,).tolist()): + raise ValueError('Incorrect truncation range') + eps = torch.finfo(self.a.dtype).eps + self._dtype_min_gt_0 = eps + self._dtype_max_lt_1 = 1 - eps + self._little_phi_a = self._little_phi(self.a) + self._little_phi_b = self._little_phi(self.b) + self._big_phi_a = self._big_phi(self.a) + self._big_phi_b = self._big_phi(self.b) + self._Z = (self._big_phi_b - self._big_phi_a).clamp_min(eps) + self._log_Z = self._Z.log() + little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan) + little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan) + self._lpbb_m_lpaa_d_Z = (self._little_phi_b * little_phi_coeff_b - self._little_phi_a * little_phi_coeff_a) / self._Z + self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z + self._variance = 1 - self._lpbb_m_lpaa_d_Z - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2 + self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z + + 
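+    # The constructor caches phi(a), phi(b), Phi(a), Phi(b) and the normalizer
+    # Z = Phi(b) - Phi(a); the cached mean, variance and entropy are the standard
+    # closed forms for a truncated standard normal, e.g. mean = -(phi(b) - phi(a)) / Z.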
@constraints.dependent_property + def support(self): + return constraints.interval(self.a, self.b) + + @property + def mean(self): + return self._mean + + @property + def variance(self): + return self._variance + + def entropy(self): + return self._entropy + + @property + def auc(self): + return self._Z + + @staticmethod + def _little_phi(x): + return (-(x ** 2) * 0.5).exp() * CONST_INV_SQRT_2PI + + @staticmethod + def _big_phi(x): + return 0.5 * (1 + (x * CONST_INV_SQRT_2).erf()) + + @staticmethod + def _inv_big_phi(x): + return CONST_SQRT_2 * (2 * x - 1).erfinv() + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1) + + def icdf(self, value): + return self._inv_big_phi(self._big_phi_a + value * self._Z) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value ** 2) * 0.5 + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + p = torch.empty(shape, device=self.a.device).uniform_(self._dtype_min_gt_0, self._dtype_max_lt_1) + return self.icdf(p) + +class TruncatedNormal(td.Normal): + def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): + super().__init__(loc, scale, validate_args=False) + self.low = low + self.high = high + self.eps = eps + + def _clamp(self, x): + clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) + x = x - x.detach() + clamped_x.detach() + return x + + def sample(self, sample_shape=torch.Size(), clip=None): + shape = self._extended_shape(sample_shape) + eps = _standard_normal(shape, + dtype=self.loc.dtype, + device=self.loc.device) + eps *= self.scale + if clip is not None: + eps = torch.clamp(eps, -clip, clip) + x = self.loc + eps + return self._clamp(x) + + +class Sigmoid2(nn.Module): + def forward(self, x): + return 2*torch.sigmoid(x/2) + +class NormalWithOffset(nn.Module): + def __init__(self, min_std: float, std_trans: str = 'sigmoid2', transform: t.Optional[str] = None): + super().__init__() + self.min_std = min_std + match std_trans: + case 'identity': + self.std_trans = nn.Identity() + case 'softplus': + self.std_trans = nn.Softplus() + case 'sigmoid': + self.std_trans = nn.Sigmoid() + case 'sigmoid2': + self.std_trans = Sigmoid2() + case _: + raise RuntimeError("Unknown std transformation") + + match transform: + case 'tanh': + self.trans = [td.TanhTransform(cache_size=1)] + case None: + self.trans = None + case _: + raise RuntimeError("Unknown distribution transformation") + + def forward(self, x): + mean, std = x.chunk(2, dim=-1) + dist = td.Normal(mean, self.std_trans(std) + self.min_std) + if self.trans is None: + return dist + else: + return td.TransformedDistribution(dist, self.trans) + +class DistLayer(nn.Module): + def __init__(self, type: str): + super().__init__() + self._dist = type + match type: + case 'mse': + self.dist = lambda x: td.Normal(x.float(), 1.0) + case 'normal': + self.dist = NormalWithOffset(min_std=0.1) + case 'onehot': + # Forcing float32 on AMP + self.dist = lambda x: td.OneHotCategoricalStraightThrough(logits=x.float()) + case 'normal_tanh': + def get_tanh_normal(x, min_std=0.1): + mean, std = x.chunk(2, dim=-1) + init_std = np.log(np.exp(5) - 1) + raise NotImplementedError() + # return TanhNormal(torch.clamp(mean, -9.0, 9.0).float(), (F.softplus(std + init_std) + min_std).float(), upscale=5) + self.dist = get_tanh_normal + case 'normal_trunc': + def get_trunc_normal(x, 
min_std=0.1): + mean, std = x.chunk(2, dim=-1) + return TruncatedNormal(loc=torch.tanh(mean).float(), scale=(2*torch.sigmoid(std/2) + min_std).float()) + self.dist = get_trunc_normal + case 'binary': + self.dist = lambda x: td.Bernoulli(logits=x.float()) + case _: + raise RuntimeError("Invalid dist layer") + + def forward(self, x): + match self._dist: + case 'onehot': + return self.dist(x) + case _: + # FIXME: verify dimensionality of independent + return td.Independent(self.dist(x), 1) + diff --git a/rl_sandbox/utils/dm_control.py b/rl_sandbox/utils/dm_control.py new file mode 100644 index 0000000..92bc6f6 --- /dev/null +++ b/rl_sandbox/utils/dm_control.py @@ -0,0 +1,39 @@ +import numpy as np +from dm_env import specs +from nptyping import Float, Int, NDArray, Shape + + +# TODO: add tests +class ActionDiscritizer: + def __init__(self, action_spec: specs.BoundedArray, values_per_dim: int): + self.actions_dim = action_spec.shape[0] + self.min = action_spec.minimum + self.max = action_spec.maximum + self.per_dim = values_per_dim + self.shape = self.per_dim**self.actions_dim + + # actions_dim X per_dim + self.grid = np.stack([np.linspace(min, max, self.per_dim, endpoint=True) for min, max in zip(self.min, self.max)]) + + def discretize(self, action: NDArray[Shape['*'], Float]) -> NDArray[Shape['*'], Int]: + ks = np.argmin((self.grid - np.ones((self.per_dim, 1)).dot(action).T)**2, axis=1) + a = 0 + for i, k in enumerate(ks): + a += k*self.per_dim**i + # ret_a = np.zeros(self.shape, dtype=np.int64) + # ret_a[a] = 1 + # return ret_a + return a + + def undiscretize(self, action: NDArray[Shape['*'], Int]) -> NDArray[Shape['*'], Float]: + ks = [] + # k = np.argmax(action) + k = action + for i in range(self.per_dim - 1, -1, -1): + ks.append(k // self.per_dim**i) + k -= ks[-1] * self.per_dim**i + + a = [] + for k, vals in zip(reversed(ks), self.grid): + a.append(vals[k]) + return np.array(a) diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py new file mode 100644 index 0000000..f4ae78b --- /dev/null +++ b/rl_sandbox/utils/env.py @@ -0,0 +1,296 @@ +import typing as t +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass + +import gym +import numpy as np +from dm_control import suite +from dm_env import Environment as dmEnviron +from dm_env import TimeStep +from nptyping import Float, Int, NDArray, Shape + +Observation = NDArray[Shape["*,*,3"], Int] +State = NDArray[Shape["*"], Float] +Action = NDArray[Shape["*"], Int] + + +@dataclass +class EnvStepResult: + obs: Observation | State + reward: float + terminated: bool + + +class ActionTransformer(metaclass=ABCMeta): + + def set_env(self, env: 'Env'): + self.low = env.action_space.low + self.high = env.action_space.high + + @abstractmethod + def transform_action(self, action): + ... + + @abstractmethod + def transform_space(self, space: gym.spaces.Box): + ... 
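+# ActionTransformer subclasses rescale or discretize the agent's action space;
+# Env.step() applies the registered transformers in reverse order before forwarding
+# the action to the wrapped environment. Hypothetical usage:
+#     env = DmEnv(run_on_pixels=True, camera_id=0, obs_res=(64, 64), repeat_action_num=2,
+#                 domain_name='walker', task_name='walk', transforms=[ActionNormalizer()])
+#     env.step(action)  # `action` given in [-1, 1], rescaled to the env bounds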
+ + +class ActionNormalizer(ActionTransformer): + + def set_env(self, env: 'Env'): + super().set_env(env) + if (~np.isfinite(self.low) | ~np.isfinite(self.high)).any(): + raise RuntimeError("Not bounded space cannot be normalized") + + def transform_action(self, action): + return (self.high - self.low) * (action + 1) / 2 + self.low + + def transform_space(self, space: gym.spaces.Box): + return gym.spaces.Box(-np.ones_like(self.low), + np.ones_like(self.high), + dtype=np.float32) + + +class ActionDisritezer(ActionTransformer): + + def __init__(self, actions_num: int): + self.per_dim = actions_num + + def set_env(self, env: 'Env'): + super().set_env(env) + if (~np.isfinite(self.low) | ~np.isfinite(self.high)).any(): + raise RuntimeError("Not bounded space cannot be discritized") + + self.grid = np.stack([ + np.linspace(min, max, self.per_dim, endpoint=True) + for min, max in zip(self.low, self.high) + ]) + + def transform_action(self, action: NDArray[Shape['*'], + Int]) -> NDArray[Shape['*'], Float]: + ks = [] + k = action + for i in range(self.per_dim - 1, -1, -1): + ks.append(k // self.per_dim**i) + k -= ks[-1] * self.per_dim**i + + a = [] + for k, vals in zip(reversed(ks), self.grid): + a.append(vals[k]) + return np.array(a) + + def transform_space(self, space: gym.spaces.Box): + return gym.spaces.Box(0, self.per_dim**len(self.low)-1, dtype=np.int32) + + +class Env(metaclass=ABCMeta): + + def __init__(self, run_on_pixels: bool, obs_res: tuple[int, int], + repeat_action_num: int, transforms: list[ActionTransformer]): + self.obs_res = obs_res + self.run_on_pixels = run_on_pixels + self.repeat_action_num = repeat_action_num + assert self.repeat_action_num >= 1 + self.ac_trans = [] + for t in transforms: + t.set_env(self) + self.ac_trans.append(t) + + def step(self, action: Action) -> EnvStepResult: + for t in reversed(self.ac_trans): + action = t.transform_action(action) + return self._step(action, self.repeat_action_num) + + @abstractmethod + def _step(self, action: Action, repeat_num: int = 1) -> EnvStepResult: + pass + + @abstractmethod + def reset(self) -> EnvStepResult: + pass + + @abstractmethod + def _observation_space(self) -> gym.Space: + pass + + @abstractmethod + def _action_space(self) -> gym.Space: + ... + + @property + def observation_space(self) -> gym.Space: + return self._observation_space() + + @property + def action_space(self) -> gym.Space: + space = self._action_space() + for t in self.ac_trans: + space = t.transform_space(t) + return space + +class AtariEnv(Env): + + def __init__(self, task_name: str, obs_res: tuple[int, int], sticky: bool, life_done: bool, greyscale: bool, + repeat_action_num: int, transforms: list[ActionTransformer]): + import gym.wrappers + import gym.envs.atari + super().__init__(True, obs_res, repeat_action_num, transforms) + + self.env: gym.Env = gym.envs.atari.AtariEnv(game=task_name, obs_type='rgb', frameskip=1, repeat_action_probability=0.25 if sticky else 0, full_action_space=False) + # Tell wrapper that the inner env has no action repeat. 
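+        # AtariPreprocessing applies its own frame skip and rejects inner envs whose
+        # spec id suggests built-in frame skipping, hence the dummy 'NoFrameskip'
+        # spec below.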
+ self.env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0') + self.env = gym.wrappers.AtariPreprocessing(self.env, + 30, repeat_action_num, obs_res[0], + life_done, greyscale) + + + def render(self): + raise RuntimeError("Render is not supported for AtariEnv") + + def _step(self, action: Action, repeat_num: int) -> EnvStepResult: + rew = 0 + for _ in range(repeat_num - 1): + new_state, reward, terminated, _ = self.env.step(action) + ts = EnvStepResult(new_state, reward, terminated) + if terminated: + break + rew += reward or 0.0 + if repeat_num == 1 or not terminated: + new_state, reward, terminated, _ = self.env.step(action) + env_res = EnvStepResult(new_state, reward, terminated) + else: + env_res = ts + env_res.reward = rew + (env_res.reward or 0.0) + return env_res + + def reset(self): + state = self.env.reset() + return EnvStepResult(state, 0, False) + + def _observation_space(self): + return self.env.observation_space + + def _action_space(self): + return self.env.action_space + +class GymEnv(Env): + + def __init__(self, task_name: str, run_on_pixels: bool, obs_res: tuple[int, int], + repeat_action_num: int, transforms: list[ActionTransformer]): + super().__init__(run_on_pixels, obs_res, repeat_action_num, transforms) + + self.task_name = task_name + if self.task_name.startswith('Crafter'): + import crafter + self.env: gym.Env = gym.make(task_name) + + if run_on_pixels: + raise NotImplementedError("Run on pixels supported only for 'dm_control'") + + def render(self): + raise RuntimeError("Render is not supported for GymEnv") + + def _step(self, action: Action, repeat_num: int) -> EnvStepResult: + rew = 0 + for _ in range(repeat_num - 1): + new_state, reward, terminated, _ = self.env.step(action) + ts = EnvStepResult(new_state, reward, terminated) + if terminated: + break + rew += reward or 0.0 + if repeat_num == 1 or not terminated: + new_state, reward, terminated, _ = self.env.step(action) + env_res = EnvStepResult(new_state, reward, terminated) + else: + env_res = ts + env_res.reward = rew + (env_res.reward or 0.0) + return env_res + + def reset(self): + state = self.env.reset() + return EnvStepResult(state, 0, False) + + def _observation_space(self): + return self.env.observation_space + + def _action_space(self): + return self.env.action_space + +class MockEnv(Env): + + def __init__(self, run_on_pixels: bool, + obs_res: tuple[int, int], repeat_action_num: int, + transforms: list[ActionTransformer]): + super().__init__(run_on_pixels, obs_res, repeat_action_num, transforms) + self.max_steps = 255 + self.step_count = 0 + + def _step(self, action: Action, repeat_num: int) -> EnvStepResult: + self.step_count += repeat_num + return EnvStepResult(self.render(), self.step_count, self.step_count >= self.max_steps) + + def reset(self): + self.step_count = 0 + return EnvStepResult(self.render(), 0, False) + + def render(self): + return np.ones(self.obs_res + (3, )) * self.step_count + + def _observation_space(self): + return gym.spaces.Box(0, 255, self.obs_res + (3, ), dtype=np.uint8) + + def _action_space(self): + return gym.spaces.Box(-1, 1, (1, ), dtype=np.float32) + + +class DmEnv(Env): + + def __init__(self, run_on_pixels: bool, + camera_id: int, + obs_res: tuple[int, int], repeat_action_num: int, + domain_name: str, task_name: str, transforms: list[ActionTransformer]): + self.camera_id = camera_id + self.env: dmEnviron = suite.load(domain_name=domain_name, task_name=task_name) + super().__init__(run_on_pixels, obs_res, repeat_action_num, transforms) + + def render(self): + 
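+        # Renders an RGB frame of size obs_res from the MuJoCo physics using the
+        # configured camera.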
return self.env.physics.render(*self.obs_res, camera_id=self.camera_id) + + def _uncode_ts(self, ts: TimeStep) -> EnvStepResult: + if self.run_on_pixels: + state = self.render() + else: + state = ts.observation + state = np.concatenate([state[s] for s in state], dtype=np.float32) + return EnvStepResult(state, ts.reward, ts.last()) + + def _step(self, action: Action, repeat_num: int) -> EnvStepResult: + rew = 0 + for _ in range(repeat_num - 1): + ts = self.env.step(action) + if ts.last(): + break + rew += ts.reward or 0.0 + if repeat_num == 1 or not ts.last(): + env_res = self._uncode_ts(self.env.step(action)) + else: + env_res = ts + env_res.reward = rew + (env_res.reward or 0.0) + return env_res + + def reset(self) -> EnvStepResult: + return self._uncode_ts(self.env.reset()) + + def _observation_space(self): + if self.run_on_pixels: + return gym.spaces.Box(0, 255, self.obs_res + (3, ), dtype=np.uint8) + else: + raise NotImplementedError( + "Currently run on pixels is only supported for 'dm_control'") + # for space in self.env.observation_spec(): + # obs_space_num = sum([v.shape[0] for v in env.observation_space().values()]) + + def _action_space(self): + spec = self.env.action_spec() + return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) diff --git a/rl_sandbox/utils/fc_nn.py b/rl_sandbox/utils/fc_nn.py index 85e2f09..ebf6320 100644 --- a/rl_sandbox/utils/fc_nn.py +++ b/rl_sandbox/utils/fc_nn.py @@ -1,15 +1,23 @@ +import typing as t from torch import nn -def fc_nn_generator(obs_space_num: int, - action_space_num: int, - hidden_layer_size: int, - num_layers: int): +def fc_nn_generator(input_num: int, + output_num: int, + hidden_size: int, + num_layers: int, + intermediate_activation: t.Type[nn.Module] = nn.ReLU, + final_activation: nn.Module = nn.Identity(), + layer_norm: bool = False): + norm_layer = nn.LayerNorm if layer_norm else nn.Identity + assert num_layers >= 3 layers = [] - layers.append(nn.Linear(obs_space_num, hidden_layer_size)) - layers.append(nn.ReLU(inplace=True)) - for _ in range(num_layers): - layers.append(nn.Linear(hidden_layer_size, hidden_layer_size)) - layers.append(nn.ReLU(inplace=True)) - layers.append(nn.Linear(hidden_layer_size, action_space_num)) - layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Linear(input_num, hidden_size)) + layers.append(nn.LayerNorm(hidden_size)) + layers.append(intermediate_activation(inplace=True)) + for _ in range(num_layers - 2): + layers.append(nn.Linear(hidden_size, hidden_size)) + layers.append(norm_layer(hidden_size)) + layers.append(intermediate_activation(inplace=True)) + layers.append(nn.Linear(hidden_size, output_num)) + layers.append(final_activation) return nn.Sequential(*layers) diff --git a/rl_sandbox/utils/logger.py b/rl_sandbox/utils/logger.py new file mode 100644 index 0000000..5ff9b97 --- /dev/null +++ b/rl_sandbox/utils/logger.py @@ -0,0 +1,99 @@ +from torch.utils.tensorboard.writer import SummaryWriter +import wandb +import typing as t +import omegaconf +from flatten_dict import flatten + + +class SummaryWriterMock(): + def __init__(self): + self.log_dir = None + + def add_scalar(*args, **kwargs): + pass + + def add_video(*args, **kwargs): + pass + + def add_image(*args, **kwargs): + pass + + def add_histogram(*args, **kwargs): + pass + + def add_figure(*args, **kwargs): + pass + +class WandbWriter(): + def __init__(self, project: str, comment: str, cfg: t.Optional[omegaconf.DictConfig]): + self.run = wandb.init( + project=project, + name=comment, + notes=comment, + 
config=flatten(omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True), reducer=lambda x, y: f"{x}-{y}" if x is not None else y) if cfg else None + ) + self.log_dir = wandb.run.dir + + def add_scalar(self, name: str, value: t.Any, global_step: int): + wandb.log({name: value}, step=global_step) + + def add_image(self, name: str, image: t.Any, global_step: int, dataformats: str = 'CHW'): + match dataformats: + case "CHW": + mode = "RGB" + case "HW": + mode = "L" + case _: + raise RuntimeError("Not supported dataformat") + wandb.log({name: wandb.Image(image, mode=mode)}, step=global_step) + + def add_video(self, name: str, video: t.Any, global_step: int, fps: int): + wandb.log({name: wandb.Video(video[0], fps=fps)}, step=global_step) + + def add_figure(self, name: str, figure: t.Any, global_step: int): + wandb.log({name: wandb.Image(figure)}, step=global_step) + +class Logger: + def __init__(self, type: t.Optional[str], + cfg: t.Optional[omegaconf.DictConfig] = None, + project: t.Optional[str] = None, + message: t.Optional[str] = None, + log_grads: bool = True, + log_dir: t.Optional[str] = None + ) -> None: + self.type = type + msg = message or "" + match type: + case "tensorboard": + self.writer = SummaryWriter(comment=msg, log_dir=log_dir) + case "wandb": + self.writer = WandbWriter(project=project, comment=msg, cfg=cfg) + case None: + self.writer = SummaryWriterMock() + case _: + raise ValueError(f"Unknown logger type: {type}") + self.log_grads = log_grads + + + def log(self, losses: dict[str, t.Any], global_step: int, mode: str = 'train'): + for loss_name, loss in losses.items(): + if 'grad' in loss_name: + if self.log_grads: + self.writer.add_histogram(f'{mode}/{loss_name}', loss, global_step) + else: + self.writer.add_scalar(f'{mode}/{loss_name}', loss.item(), global_step) + + def add_scalar(self, name: str, value: t.Any, global_step: int): + self.writer.add_scalar(name, value, global_step) + + def add_image(self, name: str, image: t.Any, global_step: int, dataformats: str = 'CHW'): + self.writer.add_image(name, image, global_step, dataformats=dataformats) + + def add_video(self, name: str, video: t.Any, global_step: int): + self.writer.add_video(name, video, global_step, fps=20) + + def add_figure(self, name: str, figure: t.Any, global_step: int): + self.writer.add_figure(name, figure, global_step) + + def log_dir(self) -> str: + return self.writer.log_dir diff --git a/rl_sandbox/utils/optimizer.py b/rl_sandbox/utils/optimizer.py new file mode 100644 index 0000000..77a176f --- /dev/null +++ b/rl_sandbox/utils/optimizer.py @@ -0,0 +1,71 @@ +import typing as t +from collections.abc import Iterable +import torch +import numpy as np +from torch import nn +from torch.optim.lr_scheduler import LRScheduler +from torch.cuda.amp import GradScaler + +from torch.optim.lr_scheduler import LinearLR, LambdaLR + +class WarmupScheduler(LinearLR): + def __init__(self, optimizer, warmup_steps): + super().__init__(optimizer, start_factor=1/warmup_steps, total_iters=int(warmup_steps)) + +# class WarmupScheduler(LambdaLR): +# def __init__(self, optimizer, warmup_steps): +# super().__init__(optimizer, lambda epoch: min(1, np.interp(epoch, [1, warmup_steps], [0, 1])) ) + +class DecayScheduler(LambdaLR): + def __init__(self, optimizer, decay_steps, decay_rate): + super().__init__(optimizer, lambda epoch: decay_rate**(epoch/decay_steps)) + +class Optimizer: + def __init__(self, model, + lr=1e-4, + eps=1e-8, + weight_decay=0.01, + lr_scheduler: t.Optional[t.Type[LRScheduler] | 
t.Iterable[t.Type[LRScheduler]]] = None, + scaler: bool = False, + log_grad: bool = False, + clip: t.Optional[float] = None): + self.model = model + self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=eps, weight_decay=weight_decay) + self.lr_scheduler = lr_scheduler + if lr_scheduler is not None and not isinstance(lr_scheduler, Iterable): + self.lr_scheduler = lr_scheduler(optimizer=self.optimizer) + elif isinstance(lr_scheduler, Iterable): + self.lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler([lr_sched(optimizer=self.optimizer) for lr_sched in lr_scheduler]) + self.log_grad = log_grad + self.scaler = GradScaler() if scaler else None + self.clip = clip + + def step(self, loss): + metrics = {} + self.optimizer.zero_grad(set_to_none=True) + + if self.scaler: + loss = self.scaler.scale(loss) + loss.backward() + + if self.scaler: + self.scaler.unscale_(self.optimizer) + + if self.log_grad: + for tag, value in self.model.named_parameters(): + metrics[f"grad/{tag.replace('.', '/')}"] = value.detach() + + if self.clip: + nn.utils.clip_grad_norm_(self.model.parameters(), self.clip) + + if self.scaler: + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.optimizer.step() + + if self.lr_scheduler: + self.lr_scheduler.step() + metrics[f'lr/{self.model.__class__.__name__}'] = torch.Tensor(self.lr_scheduler.get_last_lr()) + + return metrics diff --git a/rl_sandbox/utils/persistent_replay_buffer.py b/rl_sandbox/utils/persistent_replay_buffer.py new file mode 100644 index 0000000..bacf666 --- /dev/null +++ b/rl_sandbox/utils/persistent_replay_buffer.py @@ -0,0 +1,110 @@ +import typing as t +from collections import deque +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import webdataset as wds + +from rl_sandbox.utils.replay_buffer import (Action, Actions, Observation, + Observations, Rewards, Rollout, + State, States, TerminationFlags) + + +# TODO: add tagging of replay buffer meta-data (env config) +# to omit incompatible cache +class PersistentReplayBuffer: + + def __init__(self, directory: Path, max_len=1e6): + self.max_len: int = int(max_len) + self.directory = directory + self.directory.mkdir(exist_ok=True) + self.rollouts: list[str] = list(map(str, self.directory.glob('*.tar'))) + self.rollouts_num = len(self.rollouts) + # FIXME: add correct length calculation, currently hardcoded + self.rollouts_len: list[int] = [200] * self.rollouts_num + self.total_num = sum(self.rollouts_len) + self.rollout_idx = self.rollouts_num + + self.curr_rollout: t.Optional[Rollout] = None + self.rollouts_changed: bool = True + + def add_rollout(self, rollout: Rollout): + name = str(self.directory / f'rollout-{self.rollout_idx % self.max_len}.tar') + sink = wds.TarWriter(name) + + for idx in range(len(rollout)): + s, a, r, t = rollout.states[idx], rollout.actions[idx], rollout.rewards[ + idx], rollout.is_finished[idx] + sink.write({ + "__key__": "sample%06d" % idx, + "state.pyd": s, + "action.pyd": a, + "reward.pyd": np.array(r, dtype=np.float32), + "is_finished.pyd": np.array(t, dtype=np.bool_) + }) + + if self.rollout_idx < self.max_len: + self.total_num += len(rollout) + self.rollouts_num += 1 + self.rollouts.append(name) + self.rollouts_len.append(len(rollout)) + else: + self.total_num += len(rollout) - self.rollouts_len[self.rollout_idx % + self.max_len] + self.rollouts[self.rollout_idx % self.max_len] = name + self.rollouts_len[self.rollout_idx % self.max_len] = len(rollout) + self.rollout_idx += 1 + self.rollouts_changed = True + + # 
Add sample expects that each subsequent sample + # will be continuation of last rollout util termination flag true + # is encountered + def add_sample(self, s: State, a: Action, r: float, n: State, f: bool): + if self.curr_rollout is None: + self.curr_rollout = Rollout([s], [a], [r], None, [f]) + else: + self.curr_rollout.states.append(s) + self.curr_rollout.actions.append(a) + self.curr_rollout.rewards.append(r) + self.curr_rollout.is_finished.append(f) + + if f: + self.add_rollout( + Rollout(np.array(self.curr_rollout.states), + np.array(self.curr_rollout.actions), + np.array(self.curr_rollout.rewards, dtype=np.float32), + np.array([n]), np.array(self.curr_rollout.is_finished))) + self.curr_rollout = None + + def can_sample(self, num: int): + return self.total_num >= num + + @staticmethod + def add_next(src): + s, a, r, t = src + return s[:-1], a[:-1], r[:-1], s[1:], t[:-1] + + def sample( + self, + batch_size: int, + cluster_size: int = 1 + ) -> tuple[States, Actions, Rewards, States, TerminationFlags]: + seq_num = batch_size // cluster_size + # TODO: Could be done in async before + # NOTE: maybe use WDS_REWRITE + + if self.rollouts_changed: + # NOTE: shardshuffle will specify amount of urls that will be taken + # into account. Sorting not everything doesn't make sense + self.dataset = wds.WebDataset(self.rollouts + ).decode().to_tuple("state.pyd", "action.pyd", "reward.pyd", "is_finished.pyd" + # NOTE: does not take into account is_finished + ).batched(cluster_size + 1, partial=False + ).map(self.add_next).batched(seq_num) + # NOTE: in WebDataset github, it is recommended to use such batching by ourselves + # https://github.com/webdataset/webdataset#dataloader + self.loader = iter( + wds.WebLoader(self.dataset, batch_size=None, + num_workers=4, pin_memory=True).unbatched().shuffle(1000).unbatched().batched(batch_size)) + return next(self.loader) diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index ffdb8d4..e916ef9 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -1,47 +1,163 @@ -import random import typing as t from collections import deque +from dataclasses import dataclass, field +from unpackable import unpack +import torch import numpy as np -from nptyping import Bool, Int, Float, NDArray, Shape +from jaxtyping import Bool, Float, Int -State = NDArray[Shape["*"],Float] -Action = NDArray[Shape["*"],Int] +Observation = Int[torch.Tensor, 'n n 3'] +State = Float[torch.Tensor, 'n'] +Action = Int[torch.Tensor, 'n'] -States = NDArray[Shape["*,*"],Float] -Actions = NDArray[Shape["*,*"],Int] -Rewards = NDArray[Shape["*"],Float] -TerminationFlag = NDArray[Shape["*"],Bool] +Observations = Int[torch.Tensor, 'batch n n 3'] +States = Float[torch.Tensor, 'batch n'] +Actions = Int[torch.Tensor, 'batch n'] +Rewards = Float[torch.Tensor, 'batch'] +TerminationFlags = Bool[torch.Tensor, 'batch'] +IsFirstFlags = TerminationFlags +@dataclass +class EnvStep: + obs: Observation + action: Action + reward: float + is_finished: bool + is_first: bool + additional_data: dict[str, Float[torch.Tensor, '...']] = field(default_factory=dict) + +@dataclass +class Rollout: + obs: Observations + actions: Actions + rewards: Rewards + is_finished: TerminationFlags + is_first: IsFirstFlags + additional_data: dict[str, Float[torch.Tensor, 'batch ...']] = field(default_factory=dict) + + def __len__(self): + return len(self.obs) + + def to(self, device: str, non_blocking: bool = False): + self.obs = self.obs.to(device, non_blocking=True) + 
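+        # Every tensor field (including additional_data) is copied with
+        # non_blocking=True; for a blocking call the single stream synchronization
+        # at the end waits for all pending copies.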
self.actions = self.actions.to(device, non_blocking=True) + self.rewards = self.rewards.to(device, non_blocking=True) + self.is_finished = self.is_finished.to(device, non_blocking=True) + self.is_first = self.is_first.to(device, non_blocking=True) + for k, v in self.additional_data.items(): + self.additional_data[k] = v.to(device, non_blocking = True) + if not non_blocking: + torch.cuda.current_stream().synchronize() + return self + +@dataclass +class RolloutChunks(Rollout): + pass -# ReplayBuffer consists of next triplets: (s, a, r) class ReplayBuffer: - def __init__(self, max_len=10_000): + + def __init__(self, max_len=2e6, + prioritize_ends: bool = False, + min_ep_len: int = 1, + preprocess_func: t.Callable[[Rollout], Rollout] = lambda x: x, + device: str = 'cpu'): + self.rollouts: deque[Rollout] = deque() + self.rollouts_len: deque[int] = deque() + self.curr_rollout = None + self.min_ep_len = min_ep_len + self.prioritize_ends = prioritize_ends self.max_len = max_len - self.states: States = np.array([]) - self.actions: Actions = np.array([]) - self.rewards: Rewards = np.array([]) - self.next_states: States = np.array([]) - - def add_rollout(self, s: States, a: Actions, r: Rewards, n: States, f: TerminationFlag): - if len(self.states) == 0: - self.states = s - self.actions = a - self.rewards = r - self.next_states = n - self.is_finished = f + self.total_num = 0 + self.device = device + self.preprocess_func = preprocess_func + + def __len__(self): + return self.total_num + + def add_rollout(self, rollout: Rollout): + if len(rollout.obs) <= self.min_ep_len: + return + self.rollouts.append(self.preprocess_func(rollout).to(device='cpu')) + self.total_num += len(self.rollouts[-1].rewards) + self.rollouts_len.append(len(self.rollouts[-1].rewards)) + + while self.total_num >= self.max_len: + self.total_num -= self.rollouts_len[0] + self.rollouts_len.popleft() + self.rollouts.popleft() + + # Add sample expects that each subsequent sample + # will be continuation of last rollout util termination flag true + # is encountered + def add_sample(self, env_step: EnvStep): + s, a, r, n, f, additional = unpack(env_step) + if self.curr_rollout is None: + self.curr_rollout = Rollout([s], [a], [r], [n], [f], {k: [v] for k,v in additional.items()}) else: - self.states = np.concatenate([self.states, s]) - self.actions = np.concatenate([self.actions, a]) - self.rewards = np.concatenate([self.rewards, r]) - self.next_states = np.concatenate([self.next_states, n]) - self.is_finished = np.concatenate([self.is_finished, f]) + self.curr_rollout.obs.append(s) + self.curr_rollout.actions.append(a) + self.curr_rollout.rewards.append(r) + self.curr_rollout.is_finished.append(n) + self.curr_rollout.is_first.append(f) + for k,v in additional.items(): + self.curr_rollout.additional_data[k].append(v) + + if f: + self.add_rollout( + Rollout( + torch.stack(self.curr_rollout.obs), + torch.stack(self.curr_rollout.actions).reshape(len(self.curr_rollout.actions), -1), + torch.Tensor(self.curr_rollout.rewards), + torch.Tensor(self.curr_rollout.is_finished), + torch.Tensor(self.curr_rollout.is_first), + {k: torch.stack(v) for k,v in self.curr_rollout.additional_data.items()}) + ) + self.curr_rollout = None def can_sample(self, num: int): - return len(self.states) >= num + return self.total_num >= num + + def sample( + self, + batch_size: int, + cluster_size: int = 1 + ) -> RolloutChunks: + # NOTE: constant creation of numpy arrays from self.rollout_len seems terrible for me + s, a, r, t, is_first, additional = [], [], [], [], [], 
{} + r_indeces = np.random.choice(len(self.rollouts), batch_size) + s_indeces = [] + for r_idx in r_indeces: + rollout, r_len = self.rollouts[r_idx], self.rollouts_len[r_idx] + + assert r_len > cluster_size - 1, "Rollout it too small" + max_idx = r_len - cluster_size + 1 + if self.prioritize_ends: + max_idx += cluster_size + s_idx = min(np.random.randint(max_idx), r_len - cluster_size) + s_indeces.append(s_idx) + + is_first.append(torch.zeros(cluster_size)) + is_first[-1][0] = 1 + + s.append(rollout.obs[s_idx:s_idx + cluster_size]) + a.append(rollout.actions[s_idx:s_idx + cluster_size]) + r.append(rollout.rewards[s_idx:s_idx + cluster_size]) + t.append(rollout.is_finished[s_idx:s_idx + cluster_size]) + for k,v in rollout.additional_data.items(): + if k not in additional: + additional[k] = [] + additional[k].append(v[s_idx:s_idx + cluster_size]) + + return RolloutChunks( + obs=torch.cat(s), + actions=torch.cat(a), + rewards=torch.cat(r).float(), + is_finished=torch.cat(t), + is_first=torch.cat(is_first), + additional_data={k: torch.cat(v) for k,v in additional.items()} + ).to(self.device, non_blocking=False) + - def sample(self, num: int) -> t.Tuple[States, Actions, Rewards, States, TerminationFlag]: - indeces = list(range(len(self.states))) - random.shuffle(indeces) - indeces = indeces[:num] - return self.states[indeces], self.actions[indeces], self.rewards[indeces], self.next_states[indeces], self.is_finished[indeces] +# TODO: +# [ ] (Optional) Utilize torch's dataloader for async sampling diff --git a/rl_sandbox/utils/replay_buffer_old.py b/rl_sandbox/utils/replay_buffer_old.py new file mode 100644 index 0000000..4572376 --- /dev/null +++ b/rl_sandbox/utils/replay_buffer_old.py @@ -0,0 +1,138 @@ +import typing as t +from collections import deque +from dataclasses import dataclass + +import numpy as np +from nptyping import Bool, Float, Int, NDArray, Shape + +Observation = NDArray[Shape["*,*,3"], Int] +State = NDArray[Shape["*"], Float] | Observation +Action = NDArray[Shape["*"], Int] + +Observations = NDArray[Shape["*,*,*,3"], Int] +States = NDArray[Shape["*,*"], Float] | Observations +Actions = NDArray[Shape["*,*"], Int] +Rewards = NDArray[Shape["*"], Float] +TerminationFlags = NDArray[Shape["*"], Bool] +IsFirstFlags = TerminationFlags + + +@dataclass +class Rollout: + states: States + actions: Actions + rewards: Rewards + next_states: States + is_finished: TerminationFlags + observations: t.Optional[Observations] = None + + def __len__(self): + return len(self.states) + +# TODO: make buffer concurrent-friendly +class ReplayBuffer: + + def __init__(self, max_len=2e6, + prioritize_ends: bool = False, + min_ep_len: int = 1, + device: str = 'cpu'): + self.rollouts: deque[Rollout] = deque() + self.rollouts_len: deque[int] = deque() + self.curr_rollout = None + self.min_ep_len = min_ep_len + self.prioritize_ends = prioritize_ends + self.max_len = max_len + self.total_num = 0 + self.device = device + + def __len__(self): + return self.total_num + + def add_rollout(self, rollout: Rollout): + if len(rollout.next_states) <= self.min_ep_len: + return + # NOTE: only last next state is stored, all others are induced + # from state on next step + rollout.next_states = np.expand_dims(rollout.next_states[-1], 0) + self.rollouts.append(rollout) + self.total_num += len(self.rollouts[-1].rewards) + self.rollouts_len.append(len(self.rollouts[-1].rewards)) + + while self.total_num >= self.max_len: + self.total_num -= self.rollouts_len[0] + self.rollouts_len.popleft() + self.rollouts.popleft() + + # Add 
sample expects that each subsequent sample + # will be continuation of last rollout util termination flag true + # is encountered + def add_sample(self, s: State, a: Action, r: float, n: State, f: bool): + if self.curr_rollout is None: + self.curr_rollout = Rollout([s], [a], [r], None, [f]) + else: + self.curr_rollout.states.append(s) + self.curr_rollout.actions.append(a) + self.curr_rollout.rewards.append(r) + self.curr_rollout.is_finished.append(f) + + if f: + self.add_rollout( + Rollout(np.array(self.curr_rollout.states), + np.array(self.curr_rollout.actions).reshape(len(self.curr_rollout.actions), -1), + np.array(self.curr_rollout.rewards, dtype=np.float32), + np.array([n]), np.array(self.curr_rollout.is_finished))) + self.curr_rollout = None + + def can_sample(self, num: int): + return self.total_num >= num + + def sample( + self, + batch_size: int, + cluster_size: int = 1 + ) -> tuple[States, Actions, Rewards, States, TerminationFlags, IsFirstFlags]: + # NOTE: constant creation of numpy arrays from self.rollout_len seems terrible for me + s, a, r, n, t, is_first = [], [], [], [], [], [] + do_add_curr = self.curr_rollout is not None and len(self.curr_rollout.states) > (cluster_size * (self.prioritize_ends + 1)) + tot = self.total_num + (len(self.curr_rollout.states) if do_add_curr else 0) + r_indeces = np.random.choice(len(self.rollouts) + int(do_add_curr), + batch_size, + p=np.array(self.rollouts_len + deque([len(self.curr_rollout.states)] if do_add_curr else [])) / tot) + s_indeces = [] + for r_idx in r_indeces: + if r_idx != len(self.rollouts): + rollout, r_len = self.rollouts[r_idx], self.rollouts_len[r_idx] + else: + # -1 because we don't have next_state on terminal + rollout, r_len = self.curr_rollout, len(self.curr_rollout.states) - 1 + + assert r_len > cluster_size - 1, "Rollout it too small" + max_idx = r_len - cluster_size + 1 + if self.prioritize_ends: + s_idx = np.random.choice(max_idx - cluster_size + 1, 1).item() + cluster_size - 1 + else: + s_idx = np.random.choice(max_idx, 1).item() + s_indeces.append(s_idx) + + if r_idx == len(self.rollouts): + r_len += 1 + # FIXME: hot-fix for 1d action space, better to find smarter solution + actions = np.array(rollout.actions[s_idx:s_idx + cluster_size]).reshape(cluster_size, -1) + else: + actions = rollout.actions[s_idx:s_idx + cluster_size] + + is_first.append(np.zeros(cluster_size)) + if s_idx == 0: + is_first[-1][0] = 1 + s.append(rollout.states[s_idx:s_idx + cluster_size]) + a.append(actions) + r.append(rollout.rewards[s_idx:s_idx + cluster_size]) + t.append(rollout.is_finished[s_idx:s_idx + cluster_size]) + if s_idx != r_len - cluster_size: + n.append(rollout.states[s_idx+1:s_idx+1 + cluster_size]) + else: + if cluster_size != 1: + n.append(rollout.states[s_idx+1:s_idx+1 + cluster_size - 1]) + n.append(rollout.next_states) + return (np.concatenate(s), np.concatenate(a), np.concatenate(r, dtype=np.float32), + np.concatenate(n), np.concatenate(t), np.concatenate(is_first)) diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index bdbd503..c08012c 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -1,41 +1,122 @@ import typing as t +from collections import defaultdict +from multiprocessing.synchronize import Lock -import gym import numpy as np +import torch +import torch.multiprocessing as mp +from IPython.core.inputtransformer2 import warnings +from unpackable import unpack -from rl_sandbox.utils.replay_buffer import (Actions, 
ReplayBuffer, Rewards, - States, TerminationFlag) +from rl_sandbox.agents.random_agent import RandomAgent +from rl_sandbox.agents.rl_agent import RlAgent +from rl_sandbox.utils.env import Env +from rl_sandbox.utils.replay_buffer import EnvStep, ReplayBuffer, Rollout +# (Action, Observation, ReplayBuffer, Rollout, State) -def collect_rollout(env: gym.Env, agent: t.Optional[t.Any] = None) -> t.Tuple[States, Actions, Rewards, States, TerminationFlag]: - s, a, r, n, f = [], [], [], [], [] +# FIXME: obsolete, need to be updated for new replay buffer +# def _async_env_worker(env: Env, obs_queue: mp.Queue, act_queue: mp.Queue): +# state, _, terminated = unpack(env.reset()) +# obs_queue.put((state, 0, terminated), block=False) - obs, _ = env.reset() - terminated = False +# while not terminated: +# action = act_queue.get(block=True) + +# new_state, reward, terminated = unpack(env.step(action)) +# del action +# obs_queue.put((state, reward, terminated), block=False) + +# state = new_state + +# def iter_rollout_async( +# env: Env, +# agent: RlAgent +# ) -> t.Generator[tuple[State, Action, float, State, bool, t.Optional[Observation]], None, +# None]: +# # NOTE: maybe use SharedMemory instead +# obs_queue = mp.Queue(1) +# a_queue = mp.Queue(1) +# p = mp.Process(target=_async_env_worker, args=(env, obs_queue, a_queue)) +# p.start() +# terminated = False + +# while not terminated: +# state, reward, terminated = obs_queue.get(block=True) +# action = agent.get_action(state) +# a_queue.put(action) +# yield state, action, reward, None, terminated, state + + +def iter_rollout(env: Env, + agent: RlAgent, + collect_obs: bool = False) -> t.Generator[EnvStep, None, None]: + state, _, terminated = unpack(env.reset()) + agent.reset() + + reward = 0.0 + is_first = True + with torch.no_grad(): + action = torch.zeros_like(agent.get_action(state)) while not terminated: - if agent is None: - action = env.action_space.sample() - else: - # FIXME: you know - action = agent.get_action(obs.reshape(1, -1))[0] - new_obs, reward, terminated, _, _ = env.step(action) + try: + obs = env.render() if collect_obs else None + except RuntimeError: + # FIXME: hot-fix for Crafter env to work + warnings.warn("Cannot render environment, using state instead") + obs = state + + # FIXME: works only for crafter + yield EnvStep(obs=torch.from_numpy(state), + action=torch.Tensor(action).squeeze(), + reward=reward, + is_finished=terminated, + is_first=is_first) + is_first = False + + with torch.no_grad(): + action = agent.get_action(state) + + state, reward, terminated = unpack(env.step(action)) + + +def collect_rollout(env: Env, + agent: t.Optional[RlAgent] = None, + collect_obs: bool = False) -> Rollout: + s, a, r, t, f, additional = [], [], [], [], [], defaultdict(list) + + if agent is None: + agent = RandomAgent(env) + + for step in iter_rollout(env, agent, collect_obs): + obs, action, reward, terminated, first, add = unpack(step) s.append(obs) a.append(action) r.append(reward) - n.append(new_obs) - f.append(terminated) - obs = new_obs - return np.array(s), np.array(a).reshape(len(s), -1), np.array(r, dtype=np.float32), np.array(n), np.array(f) + t.append(terminated) + f.append(first) + for k, v in add.items(): + additional[k].append(v) + + return Rollout(torch.stack(s), torch.stack(a).reshape(len(a), -1), + torch.Tensor(r).float(), torch.Tensor(t), torch.Tensor(f), + {k: torch.stack(v) + for k, v in additional.items()}) + -def collect_rollout_num(env: gym.Env, num: int, agent: t.Optional[t.Any] = None) -> t.List[t.Tuple[States, Actions, 
Rewards, States, TerminationFlag]]: +def collect_rollout_num(env: Env, + num: int, + agent: t.Optional[t.Any] = None, + collect_obs: bool = False) -> t.List[Rollout]: + # TODO: paralelyze rollouts = [] for _ in range(num): - rollouts.append(collect_rollout(env, agent)) + rollouts.append(collect_rollout(env, agent, collect_obs)) return rollouts -def fillup_replay_buffer(env: gym.Env, rep_buffer: ReplayBuffer, num: int): +def fillup_replay_buffer(env: Env, rep_buffer: ReplayBuffer, num: int, agent: t.Optional[RlAgent] = None): + # TODO: paralelyze while not rep_buffer.can_sample(num): - s, a, r, n, f = collect_rollout(env) - rep_buffer.add_rollout(s, a, r, n, f) + rep_buffer.add_rollout(collect_rollout(env, agent=agent, collect_obs=False)) diff --git a/rl_sandbox/utils/schedulers.py b/rl_sandbox/utils/schedulers.py new file mode 100644 index 0000000..a68876c --- /dev/null +++ b/rl_sandbox/utils/schedulers.py @@ -0,0 +1,25 @@ +from abc import ABCMeta + +import numpy as np + +class Scheduler(metaclass=ABCMeta): + def step(self) -> float: + ... + +class LinearScheduler(Scheduler): + def __init__(self, initial_value, final_value, duration): + self._init = initial_value + self._final = final_value + self._dur = duration - 1 + self._curr_t = 0 + + @property + def val(self) -> float: + if self._curr_t >= self._dur: + return self._final + return np.interp([self._curr_t], [0, self._dur], [self._init, self._final]).item() + + def step(self) -> float: + val = self.val + self._curr_t += 1 + return val diff --git a/rl_sandbox/vision/dino.py b/rl_sandbox/vision/dino.py new file mode 100644 index 0000000..37a3bc6 --- /dev/null +++ b/rl_sandbox/vision/dino.py @@ -0,0 +1,360 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Taken from YangtaoWANG95/TokenCut/unsupervised_saliency_detection/dino.py + +Copied from Dino repo. https://github.com/facebookresearch/dino +Mostly copy-paste from timm library. +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" +import math +from functools import partial +import warnings + +import torch +import torch.nn as nn + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. 
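
For reference, here is a usage sketch of how the reworked `ReplayBuffer` in `rl_sandbox/utils/replay_buffer.py` and the rollout helpers in `rl_sandbox/utils/rollout_generation.py` added earlier in this diff are meant to fit together. It only uses calls that appear in this change; `make_env()` is a placeholder, since the `Env` wrapper's constructor is not part of this diff.

```python
# Sketch only: make_env() is hypothetical; everything else is taken from this diff.
from rl_sandbox.agents.random_agent import RandomAgent
from rl_sandbox.utils.replay_buffer import ReplayBuffer
from rl_sandbox.utils.rollout_generation import (collect_rollout,
                                                 fillup_replay_buffer,
                                                 iter_rollout)

env = make_env()          # placeholder for the Env wrapper used by train.py
agent = RandomAgent(env)

buffer = ReplayBuffer(max_len=1_000_000, min_ep_len=2, prioritize_ends=False)

# Whole episodes at once...
buffer.add_rollout(collect_rollout(env, agent, collect_obs=False))

# ...or step by step; per the comment above add_sample, the buffer assembles the
# episode internally and finalizes it once the termination flag is encountered.
for env_step in iter_rollout(env, agent):
    buffer.add_sample(env_step)

# Keep collecting rollouts until at least 128 steps are stored, then sample
# 16 chunks of 8 consecutive steps each.
fillup_replay_buffer(env, buffer, num=128)
chunks = buffer.sample(batch_size=16, cluster_size=8)
# chunks.obs concatenates all chunks, shape (16 * 8, *obs_shape);
# chunks.is_first has a 1 at the first step of every sampled chunk.
```
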
+ # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + feat_qkv = self.qkv(x) + qkv = feat_qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, (attn, feat_qkv) + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, return_attention=False): + y, attn = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size // patch_size) * (img_size // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): + super().__init__() + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def interpolate_pos_encoding(self, x, w, h): + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_embed.patch_size + h0 = h // self.patch_embed.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def 
prepare_tokens(self, x): + B, nc, w, h = x.shape + x = self.patch_embed(x) # patch linear embedding + + # add the [CLS] token to the embed patch tokens + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # add positional encoding to each token + x = x + self.interpolate_pos_encoding(x, w, h) + + return self.pos_drop(x) + + def forward(self, x): + x = self.prepare_tokens(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + return x[:, 0] + + def get_last_selfattention(self, x): + x = self.prepare_tokens(x) + for i, blk in enumerate(self.blocks): + if i < len(self.blocks) - 1: + x = blk(x) + else: + # return attention of the last block + return blk(x, return_attention=True) + + def get_intermediate_layers(self, x, n=1): + x = self.prepare_tokens(x) + # we return the output tokens from the `n` last blocks + output = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if len(self.blocks) - i <= n: + output.append(self.norm(x)) + return output + + + +def vit_small(patch_size=16, img_size=[224], **kwargs): + model = VisionTransformer( + img_size=img_size, + patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_base(patch_size=16, img_size=[224], **kwargs): + model = VisionTransformer( + img_size=img_size, + patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + + + +class ViTFeat(nn.Module): + """ Vision Transformer """ + def __init__(self, pretrained_pth, feat_dim, vit_arch = 'base', vit_feat = 'k', patch_size=16, img_size=[224]): + super().__init__() + if vit_arch == 'base' : + self.model = vit_base(patch_size=patch_size, num_classes=0, img_size=img_size) + + else : + self.model = vit_small(patch_size=patch_size, num_classes=0, img_size=img_size) + + self.feat_dim = feat_dim + self.vit_feat = vit_feat + self.patch_size = patch_size + +# state_dict = torch.load(pretrained_pth, map_location="cpu") + state_dict = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com"+pretrained_pth) + self.model.load_state_dict(state_dict, strict=True) + print('Loading weight from {}'.format(pretrained_pth)) + + + def forward(self, img) : + feat_out = {} + + + # Forward pass in the model + with torch.no_grad() : + h, w = img.shape[2], img.shape[3] + feat_h, feat_w = h // self.patch_size, w // self.patch_size + attentions = self.model.get_last_selfattention(img) + attentions, feat_qkv = attentions + bs, nb_head, nb_token = attentions.shape[0], attentions.shape[1], attentions.shape[2] + qkv = ( + feat_qkv + .reshape(bs, nb_token, 3, nb_head, -1) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + + k = k.transpose(1, 2).reshape(bs, nb_token, -1) + q = q.transpose(1, 2).reshape(bs, nb_token, -1) + v = v.transpose(1, 2).reshape(bs, nb_token, -1) + + # Modality selection + if self.vit_feat == "k": + feats = k[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) + elif self.vit_feat == "q": + feats = q[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) + elif self.vit_feat == "v": + feats = v[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) + elif self.vit_feat == "kqv": + k = k[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) + q = q[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) + v = v[:, 1:].transpose(1, 2).reshape(bs, 
self.feat_dim, feat_h * feat_w) + feats = torch.cat([k, q, v], dim=1) + return feats + + +if __name__ == "__main__": + model = ViTFeat('/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', 64, 'base', 'k', patch_size=8) + img = torch.FloatTensor(4, 3, 64, 64) + # Forward pass in the model + feat = model(img) + print (feat[0].shape) diff --git a/rl_sandbox/vision/slot_attention.py b/rl_sandbox/vision/slot_attention.py new file mode 100644 index 0000000..21975ea --- /dev/null +++ b/rl_sandbox/vision/slot_attention.py @@ -0,0 +1,299 @@ +import torch +import typing as t +from torch import nn +import torch.nn.functional as F +from jaxtyping import Float +import torchvision as tv +from tqdm import tqdm +import numpy as np + +from rl_sandbox.vision.dino import ViTFeat +from rl_sandbox.utils.logger import Logger + +class SlotAttention(nn.Module): + def __init__(self, num_slots: int, n_dim: int, n_iter: int, use_prev_slots: bool): + super().__init__() + + self.n_slots = num_slots + self.n_iter = n_iter + self.n_dim = n_dim + self.scale = self.n_dim**(-1/2) + self.epsilon = 1e-8 + + self.use_prev_slots = use_prev_slots + if use_prev_slots: + self.slots_mu = nn.Parameter(torch.randn(1, 1, self.n_dim)) + self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.n_dim)) + else: + self.slots_mu = nn.Parameter(torch.randn(1, num_slots, self.n_dim)) + self.slots_logsigma = nn.Parameter(torch.zeros(1, num_slots, self.n_dim)) + nn.init.xavier_uniform_(self.slots_logsigma) + + self.slots_proj = nn.Linear(n_dim, n_dim, bias=False) + self.slots_proj_2 = nn.Sequential( + nn.Linear(n_dim, n_dim*4), + nn.ReLU(inplace=True), + nn.Linear(n_dim*4, n_dim), + ) + self.slots_norm = nn.LayerNorm(self.n_dim) + self.slots_norm_2 = nn.LayerNorm(self.n_dim) + self.slots_reccur = nn.GRUCell(input_size=self.n_dim, hidden_size=self.n_dim) + + self.inputs_proj = nn.Linear(n_dim, n_dim*2, bias=False) + self.inputs_norm = nn.LayerNorm(self.n_dim) + self.prev_slots = None + + def generate_initial(self, batch: int): + mu = self.slots_mu.expand(batch, self.n_slots, -1) + sigma = self.slots_logsigma.exp().expand(batch, self.n_slots, -1) + slots = mu + sigma * torch.randn(mu.shape, device=mu.device) + return slots + + def forward(self, X: Float[torch.Tensor, 'batch seq n_dim'], prev_slots: t.Optional[Float[torch.Tensor, 'batch num_slots n_dim']]) -> Float[torch.Tensor, 'batch num_slots n_dim']: + batch, _, _ = X.shape + k, v = self.inputs_proj(self.inputs_norm(X)).chunk(2, dim=-1) + + if prev_slots is None: + slots = self.generate_initial(batch) + self.prev_slots = slots.clone() + else: + slots = prev_slots + + self.last_attention = None + + for _ in range(self.n_iter): + slots_prev = slots + slots = self.slots_norm(slots) + q = self.slots_proj(slots) + + attn = F.softmax(self.scale*torch.einsum('bik,bjk->bij', q, k).float(), dim=1) + self.epsilon + attn = attn / attn.sum(dim=-1, keepdim=True) + + self.last_attention = attn + + updates = torch.einsum('bjd,bij->bid', v, attn) + slots = self.slots_reccur(updates.reshape(-1, self.n_dim), slots_prev.reshape(-1, self.n_dim)).reshape(batch, self.n_slots, self.n_dim) + slots = slots + self.slots_proj_2(self.slots_norm_2(slots)) + return slots + +def build_grid(resolution): + ranges = [np.linspace(0., 1., num=res) for res in resolution] + grid = np.meshgrid(*ranges, sparse=False, indexing="ij") + grid = np.stack(grid, axis=-1) + grid = np.reshape(grid, [resolution[0], resolution[1], -1]) + grid = np.expand_dims(grid, axis=0) + grid = grid.astype(np.float32) + return np.concatenate([grid, 
1.0 - grid], axis=-1) + + +class PositionalEmbedding(nn.Module): + def __init__(self, n_dim: int, res: t.Tuple[int, int], channel_last=False): + super().__init__() + self.n_dim = n_dim + self.proj = nn.Linear(4, n_dim) + self.channel_last = channel_last + self.register_buffer('grid', torch.from_numpy(build_grid(res))) + + def forward(self, X) -> torch.Tensor: + if self.channel_last: + return X + self.proj(self.grid) + else: + return X + self.proj(self.grid).permute(0, 3, 1, 2) + +class SlottedAutoEncoder(nn.Module): + def __init__(self, num_slots: int, n_iter: int, dino_inp_size: int = 224): + super().__init__() + in_channels = 3 + self.n_dim = 128 + self.lat_dim = int(self.n_dim**0.5) + self.encoder = nn.Sequential( + nn.Conv2d(in_channels, 64, kernel_size=5, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, self.n_dim, kernel_size=5, padding='same'), + nn.ReLU(inplace=True), + ) + + self.mlp = nn.Sequential( + nn.Linear(self.n_dim, self.n_dim), + nn.ReLU(inplace=True), + nn.Linear(self.n_dim, self.n_dim) + ) + + self.dino_inp_size = dino_inp_size + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=16) + self.vit_patch_num = self.dino_inp_size // self.dino_vit.patch_size + self.vit_feat = self.dino_vit.feat_dim + + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (7, 7)) + self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) + self.positional_augmenter_vit_dec = PositionalEmbedding(self.n_dim, (self.lat_dim, self.lat_dim)) + self.slot_attention = SlotAttention(num_slots, self.n_dim, n_iter) + self.img_decoder = nn.Sequential( # Dx8x8 -> (3+1)x64x64 + nn.ConvTranspose2d(self.n_dim, 48, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(48, 96, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(96, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ConvTranspose2d(192, 4, kernel_size=3, stride=(1, 1), padding=1), + nn.ReLU(inplace=True), + ) + + self.vit_decoder = nn.Sequential( # Dx1x1 -> (384+1)x14x14 + nn.ConvTranspose2d(self.n_dim, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(192, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(192, self.vit_feat, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(self.vit_feat, self.vit_feat*2, kernel_size=3, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(self.vit_feat*2, self.vit_feat+1, kernel_size=3, stride=(1, 1), padding=1), + ) + + # self.vit_decoder_mlp = nn.Sequential( + # nn.Linear(self.n_dim, 1024), + # nn.ReLU(inplace=True), + # nn.Linear(1024, 1024), + # nn.ReLU(inplace=True), + # nn.Linear(1024, 1024), + # nn.ReLU(inplace=True), + # nn.Linear(1024, self.vit_feat+1), + # nn.ReLU(inplace=True) + # ) + + def forward(self, X: Float[torch.Tensor, 'batch 3 h w'], prev_slots: t.Optional[Float[torch.Tensor, 'batch num_slots n_dim']] = None) -> t.Dict[str, torch.Tensor]: + features = self.encoder(X) # -> batch D h w + features_with_pos_enc = self.positional_augmenter_inp(features) # -> batch D h w + + resize = 
tv.transforms.Resize(self.dino_inp_size, antialias=True) + + batch, seq, _, _ = X.shape + vit_features = self.dino_vit(resize(X)) + vit_features = vit_features.reshape(batch, -1, self.vit_patch_num, self.vit_patch_num) + + pre_slot_features = self.mlp(features_with_pos_enc.permute(0, 2, 3, 1).reshape(batch, -1, self.n_dim)) + + slots = self.slot_attention(pre_slot_features, prev_slots) # -> batch num_slots D + slots_grid = slots.flatten(0, 1).reshape(-1, 1, 1, self.n_dim).permute(0, 3, 1, 2) + + # slots_with_vit_pos_enc = self.positional_augmenter_vit_dec(slots_grid.flatten(2, 3).repeat((1, 1, 196)).reshape(-1, self.n_dim, self.lat_dim, self.lat_dim)).flatten(2, 3) + # decoded_features, vit_masks =self.vit_decoder_mlp(slots_with_vit_pos_enc).reshape(batch, -1, self.vit_patch_num, self.vit_patch_num, self.vit_feat+1).split([self.vit_feat, 1], dim=-1) + + decoded_features, vit_masks = self.vit_decoder(slots_grid).permute(0, 2, 3, 1).reshape(batch, -1, self.vit_patch_num, self.vit_patch_num, self.vit_feat+1).split([self.vit_feat, 1], dim=-1) + vit_mask = F.softmax(vit_masks, dim=1) + + rec_features = (decoded_features * vit_mask).sum(dim=1) + + slots_grid = slots_grid.repeat((1, 1, 8, 8)) # -> batch*num_slots D sqrt(D) sqrt(D) + slots_with_pos_enc = self.positional_augmenter_dec(slots_grid) + + decoded_imgs, masks = self.img_decoder(slots_with_pos_enc).permute(0, 2, 3, 1).reshape(batch, -1, *np.array(X.shape[2:]), 4).split([3, 1], dim=-1) + img_mask = F.softmax(masks, dim=1) + + decoded_imgs = decoded_imgs * img_mask + rec_img = torch.sum(decoded_imgs, dim=1) + return { + 'rec_img': rec_img.permute(0, 3, 1, 2), + 'img_per_slot': decoded_imgs.permute(0, 1, 4, 2, 3), + 'vit_mask': vit_mask, + 'vit_rec_loss': F.mse_loss(rec_features.permute(0, 3, 1, 2), vit_features), + 'slots': slots + } + +if __name__ == '__main__': + device = 'cuda' + debug = False + ToTensor = tv.transforms.Compose([tv.transforms.ToTensor(), + tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + + train_data = tv.datasets.ImageFolder('~/rl_sandbox/crafter_data/', transform=ToTensor) + if debug: + train_data_loader = torch.utils.data.DataLoader(train_data, + batch_size=4, + prefetch_factor=1, + shuffle=False, + num_workers=2) + else: + train_data_loader = torch.utils.data.DataLoader(train_data, + batch_size=32, + shuffle=True, + num_workers=8) + + import socket + from datetime import datetime + current_time = datetime.now().strftime("%b%d_%H-%M-%S") + comment = "Added vit masks logging, lambda=0.1, return old dino".replace(" ", "_") + logger = Logger(None if debug else 'tensorboard', message=comment, log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}_{comment}") + + number_of_slots = 7 + slots_iter_num = 3 + + total_steps = 5e5 + warmup_steps = 1e4 + decay_rate = 0.5 + decay_steps = 1e5 + val_every = 1e4 + + model = SlottedAutoEncoder(number_of_slots, slots_iter_num).to(device) + # model = torch.compile(model) + optimizer = torch.optim.Adam(model.parameters(), lr=4e-4) + lr_warmup_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1/warmup_steps, total_iters=int(warmup_steps)) + lr_decay_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: decay_rate**(epoch/decay_steps)) + # lr_decay_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate**(1/decay_steps)) + lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler([lr_warmup_scheduler, lr_decay_scheduler]) + + global_step = 0 + prev_global_step = 0 + epoch = 0 + pbar = 
tqdm(total=total_steps, desc='Training') + while global_step < total_steps: + for sample_num, (img, target) in enumerate(train_data_loader): + res = model(img.to(device)) + recovered_img, vit_rec_loss = res['rec_img'], res['vit_rec_loss'] + + reg_loss = F.mse_loss(img.to(device), recovered_img) + + lambda_ = 0.1 + loss = lambda_ * reg_loss + (1 - lambda_) * vit_rec_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + logger.add_scalar('train/img_rec_loss', reg_loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) + logger.add_scalar('train/vit_rec_loss', vit_rec_loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) + logger.add_scalar('train/loss', loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) + pbar.update(1) + global_step += len(train_data_loader) + + epoch += 1 + logger.add_scalar('epoch', epoch, epoch) + + if global_step - prev_global_step > val_every: + prev_global_step = global_step + else: + continue + + for i in range(3): + img, target = next(iter(train_data_loader)) + res = model(img.to(device)) + recovered_img, imgs_per_slot, vit_mask = res['rec_img'], res['img_per_slot'], res['vit_mask'] + upscale = tv.transforms.Resize(64, antialias=True) + unnormalize = tv.transforms.Compose([ + tv.transforms.Normalize((0, 0, 0), (1/0.229, 1/0.224, 1/0.225)), + tv.transforms.Normalize((-0.485, -0.456, -0.406), (1., 1., 1.)) + ]) + logger.add_image(f'val/example_image', unnormalize(img.cpu().detach()[0]), epoch*3 + i) + logger.add_image(f'val/example_image_rec', unnormalize(recovered_img.cpu().detach()[0]), epoch*3 + i) + per_slot_img = unnormalize(imgs_per_slot.cpu().detach())[0].permute((1, 2, 0, 3)).flatten(2, 3) + logger.add_image(f'val/example_image_slot_rec', per_slot_img, epoch*3 + i) + upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) + per_slot_vit = (upscaled_mask.unsqueeze(2) * img.to(device).unsqueeze(1))[0].permute(1, 2, 0, 3).flatten(2, 3) + logger.add_image(f'val/example_vit_slot_mask', per_slot_vit/upscaled_mask.max(), epoch*3 + i) + diff --git a/rl_sandbox/vision/vae.py b/rl_sandbox/vision/vae.py new file mode 100644 index 0000000..8a8d133 --- /dev/null +++ b/rl_sandbox/vision/vae.py @@ -0,0 +1,177 @@ +from collections import defaultdict +from pathlib import Path + +import numpy as np +import torch +import torchvision +from PIL.Image import Image +from torch import nn +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + + +class ResBlock(nn.Module): + + def __init__(self, in_channels, hidden_units=256): + super().__init__() + + self.block = nn.Sequential( + nn.ReLU(), nn.Conv2d(in_channels, hidden_units, kernel_size=3, + padding='same'), nn.ReLU(inplace=True), + nn.Conv2d(hidden_units, in_channels, kernel_size=1, padding='same')) + + def forward(self, X): + output = self.block(X) + return X + output + + +class VAE(nn.Module): + + def __init__(self, latent_dim=3, kl_weight=2.5e-4): + super().__init__() + self.latent_dim = latent_dim + self.kl_weight = kl_weight + + in_channels = 3 + out_channels = 128 + + self.encoder = nn.Sequential( + nn.BatchNorm2d(3), + nn.Conv2d(in_channels, out_channels // 2, kernel_size=4, stride=2, + padding=1), # 32 -> 16 + nn.LeakyReLU(inplace=True), + nn.BatchNorm2d(out_channels // 2), + nn.Conv2d(out_channels // 2, out_channels, kernel_size=4, stride=2, + padding=1), # 16 -> 8 + nn.LeakyReLU(inplace=True), + ResBlock(out_channels), + ResBlock(out_channels), + nn.Conv2d(out_channels, 4, 1), # 4x8x8 + nn.Flatten()) + + self.f_mu 
= nn.Linear(256, self.latent_dim) + self.f_log_sigma = nn.Linear(256, self.latent_dim) + + self.decoder_1 = nn.Sequential( + nn.Linear(self.latent_dim, 256), + nn.LeakyReLU(inplace=True), + ) + + self.decoder_2 = nn.Sequential( + nn.Conv2d(4, out_channels, 1), + ResBlock(out_channels), + ResBlock(out_channels), + nn.BatchNorm2d(out_channels), + nn.ConvTranspose2d(out_channels, + out_channels // 2, + kernel_size=4, + stride=2, + padding=1), + nn.LeakyReLU(inplace=True), + nn.BatchNorm2d(out_channels // 2), + nn.ConvTranspose2d(out_channels // 2, + in_channels, + kernel_size=4, + stride=2, + padding=1), + nn.LeakyReLU(inplace=True), + ) + + def forward(self, X): + z_h = self.encoder(X) + + z_mu = self.f_mu(z_h) + z_log_sigma = self.f_log_sigma(z_h) + + device = next(self.f_mu.parameters()).device + z = z_mu + z_log_sigma.exp() * torch.rand_like(z_mu).to(device) + + x_h_1 = self.decoder_1(z) + x_h = self.decoder_2(x_h_1.view(-1, 4, 8, 8)) + return x_h, z_mu, z_log_sigma + + def calculate_loss(self, x, x_h, z_mu, z_log_sigma) -> dict[str, torch.Tensor]: + # loss = log p(x | z) + KL(q(z) || p(z)) + # p(z) = N(0, 1) + L_rec = torch.nn.MSELoss() + + loss_kl = -1 * torch.mean(torch.sum( + z_log_sigma + 0.5 * (1 - z_log_sigma.exp()**2 - z_mu**2), dim=1), + dim=0) + loss_rec = L_rec(x, x_h) + + return { + 'loss': loss_rec + self.kl_weight * loss_kl, + 'loss_rec': loss_rec, + 'loss_kl': loss_kl + } + + +def image_preprocessing(img: Image): + return torchvision.transforms.ToTensor()(img) + + +if __name__ == "__main__": + import torch.multiprocessing + + # fix for "unable to open shared memory on mac" + torch.multiprocessing.set_sharing_strategy('file_system') + + train_data = torchvision.datasets.CIFAR10(str(Path() / 'data' / 'cifar10'), + download=True, + train=True, + transform=image_preprocessing) + test_data = torchvision.datasets.CIFAR10(str(Path() / 'data' / 'cifar10'), + download=True, + train=False, + transform=image_preprocessing) + train_data_loader = torch.utils.data.DataLoader(train_data, + batch_size=128, + shuffle=True, + num_workers=8) + test_data_loader = torch.utils.data.DataLoader(test_data, + batch_size=128, + shuffle=True, + num_workers=8) + import socket + from datetime import datetime + current_time = datetime.now().strftime("%b%d_%H-%M-%S") + logger = SummaryWriter(log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}") + + device = 'mps' + model = VAE(latent_dim=256).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=2e-4) + + for epoch in tqdm(range(100)): + + logger.add_scalar('epoch', epoch, epoch) + + for sample_num, (img, target) in enumerate(train_data_loader): + recovered_img, z_mu, z_log_sigma = model(img.to(device)) + + losses = model.calculate_loss(img.to(device), recovered_img, z_mu, + z_log_sigma) + loss = losses['loss'] + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + for loss_kind in losses: + logger.add_scalar(f'train/{loss_kind}', losses[loss_kind].cpu().detach(), + epoch * len(train_data_loader) + sample_num) + + val_losses = defaultdict(list) + for img, target in test_data_loader: + recovered_img, z_mu, z_log_sigma = model(img.to(device)) + losses = model.calculate_loss(img.to(device), recovered_img, z_mu, + z_log_sigma) + + for loss_kind in losses: + val_losses[loss_kind].append(losses[loss_kind].cpu().detach()) + + for loss_kind in val_losses: + logger.add_scalar(f'val/{loss_kind}', np.mean(val_losses[loss_kind]), epoch) + logger.add_image(f'val/example_image', img.cpu().detach()[0], epoch) + 
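
As a quick sanity check, the closed-form KL term in `VAE.calculate_loss` above can be compared against `torch.distributions`; the sketch below uses arbitrary shapes. Note that this closed form is the KL between a diagonal Gaussian posterior and a standard normal prior, which assumes the reparameterisation noise is drawn with `torch.randn_like`, while `forward` currently uses `torch.rand_like` (uniform noise).

```python
import torch
from torch.distributions import Normal, kl_divergence

# Arbitrary batch of posterior parameters (batch=8, latent_dim=256).
z_mu = torch.randn(8, 256)
z_log_sigma = 0.1 * torch.randn(8, 256)

# Closed form used in calculate_loss: KL(N(mu, sigma) || N(0, 1)),
# summed over latent dims and averaged over the batch.
loss_kl = -1 * torch.mean(
    torch.sum(z_log_sigma + 0.5 * (1 - z_log_sigma.exp()**2 - z_mu**2), dim=1), dim=0)

# Reference value from torch.distributions.
ref = kl_divergence(Normal(z_mu, z_log_sigma.exp()),
                    Normal(0.0, 1.0)).sum(dim=1).mean()
assert torch.allclose(loss_kl, ref, rtol=1e-4)

# Reparameterisation trick this KL term goes with: z = mu + sigma * eps, eps ~ N(0, 1).
eps = torch.randn_like(z_mu)
z = z_mu + z_log_sigma.exp() * eps
```
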
logger.add_image(f'val/example_image_rec', + recovered_img.cpu().detach()[0], epoch) diff --git a/rl_sandbox/vision/vq_vae.py b/rl_sandbox/vision/vq_vae.py new file mode 100644 index 0000000..c1efdc4 --- /dev/null +++ b/rl_sandbox/vision/vq_vae.py @@ -0,0 +1,165 @@ +from collections import defaultdict +from pathlib import Path + +import numpy as np +import torch +import torchvision +from PIL.Image import Image +from torch import nn +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +from rl_sandbox.vision.vae import ResBlock + + +class VQ_VAE(nn.Module): + + def __init__(self, latent_space_size, latent_dim, beta=0.25): + super().__init__() + # amount of the discrete vectors + self.latent_space_size = latent_space_size + # dimensionality of each category + self.latent_dim = latent_dim + self.beta = beta + + self.latent_space = torch.nn.Parameter( + torch.empty(size=(self.latent_space_size, self.latent_dim))) + torch.nn.init.kaiming_uniform_(self.latent_space) + + in_channels = 3 + out_channels = 128 + + self.encoder = nn.Sequential( + nn.BatchNorm2d(3), + nn.Conv2d(in_channels, out_channels // 2, kernel_size=4, stride=2, + padding=1), # 32 -> 16 + nn.LeakyReLU(inplace=True), + nn.BatchNorm2d(out_channels // 2), + nn.Conv2d(out_channels // 2, out_channels, kernel_size=4, stride=2, + padding=1), # 16 -> 8 + nn.LeakyReLU(inplace=True), + ResBlock(out_channels), + ResBlock(out_channels), + nn.Conv2d(out_channels, latent_dim, 1), # Dx8x8 + ) + + self.decoder = nn.Sequential( + nn.Conv2d(latent_dim, out_channels, 1), + ResBlock(out_channels), + ResBlock(out_channels), + nn.BatchNorm2d(out_channels), + nn.ConvTranspose2d(out_channels, + out_channels // 2, + kernel_size=4, + stride=2, + padding=1), + nn.LeakyReLU(inplace=True), + nn.BatchNorm2d(out_channels // 2), + nn.ConvTranspose2d(out_channels // 2, + in_channels, + kernel_size=4, + stride=2, + padding=1), + nn.LeakyReLU(inplace=True), + ) + + def quantize(self, z): + # z <- BxDxHxW + # Pytorch BUG: https://github.com/pytorch/pytorch/issues/84206 + # .to(memory_format=torch.contiguous_format) should be used instead of .contigious() on mac m1 + latents = torch.permute(z, (0, 2, 3, 1)).to(memory_format=torch.contiguous_format) # BxHxWxD + flatten = latents.view(-1, self.latent_dim) # BHWxD + + # use the property that (a - b)^2 = a^2 - 2ab + b^2 + l2_dist = torch.sum(flatten**2, dim=1, keepdim=True) - 2 * ( + flatten @ self.latent_space.T) + torch.sum(self.latent_space**2, dim=1) # BHWxK + + ks = torch.argmin(l2_dist, dim=1) + + flatten_quantized_latents = torch.index_select(self.latent_space, 0, ks) # BHWxD + e = flatten_quantized_latents.view(latents.shape).permute((0, 3, 1, 2)).to(memory_format=torch.contiguous_format) + return e + (z - z.detach()) + + + def forward(self, X): + z = self.encoder(X) + e = self.quantize(z) + x_h = self.decoder(e) + return x_h, z, e + + def calculate_loss(self, x, x_h, z, e) -> dict[str, torch.Tensor]: + # loss = log p(x | z) + || stop_grad(e) - z ||_2 + beta *|| e - stop_grad(z) ||_2 + L_rec = torch.nn.MSELoss() + + loss_reg = torch.norm(e.detach() - z, + p=2) + self.beta * torch.norm(e - z.detach(), p=2) + loss_rec = L_rec(x, x_h) + + return {'loss': loss_rec + loss_reg, 'loss_rec': loss_rec, 'loss_reg': loss_reg} + + +def image_preprocessing(img: Image): + return torchvision.transforms.ToTensor()(img) + + +if __name__ == "__main__": + # fix for "unable to open shared memory on mac" + torch.multiprocessing.set_sharing_strategy('file_system') + + train_data = 
torchvision.datasets.CIFAR10(str(Path() / 'data' / 'cifar10'), + download=True, + train=True, + transform=image_preprocessing) + test_data = torchvision.datasets.CIFAR10(str(Path() / 'data' / 'cifar10'), + download=True, + train=False, + transform=image_preprocessing) + train_data_loader = torch.utils.data.DataLoader(train_data, + batch_size=128, + shuffle=True, + num_workers=8) + test_data_loader = torch.utils.data.DataLoader(test_data, + batch_size=128, + shuffle=True, + num_workers=8) + import socket + from datetime import datetime + current_time = datetime.now().strftime("%b%d_%H-%M-%S") + logger = SummaryWriter(log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}") + + device = 'mps' + model = VQ_VAE(latent_space_size=256, latent_dim=1).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=2e-4) + + for epoch in tqdm(range(100)): + + logger.add_scalar('epoch', epoch, epoch) + + for sample_num, (img, target) in enumerate(train_data_loader): + recovered_img, z, e = model(img.to(device)) + + losses = model.calculate_loss(img.to(device), recovered_img, z, e) + loss = losses['loss'] + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + for loss_kind in losses: + logger.add_scalar(f'train/{loss_kind}', losses[loss_kind].cpu().detach(), + epoch * len(train_data_loader) + sample_num) + + val_losses = defaultdict(list) + for img, target in test_data_loader: + recovered_img, z_mu, z_log_sigma = model(img.to(device)) + losses = model.calculate_loss(img.to(device), recovered_img, z_mu, + z_log_sigma) + + for loss_kind in losses: + val_losses[loss_kind].append(losses[loss_kind].cpu().detach()) + + for loss_kind in val_losses: + logger.add_scalar(f'val/{loss_kind}', np.mean(val_losses[loss_kind]), epoch) + logger.add_image(f'val/example_image', img.cpu().detach()[0], epoch) + logger.add_image(f'val/example_image_rec', + recovered_img.cpu().detach()[0], epoch) diff --git a/tests/test_replay_buffer.py b/tests/test_replay_buffer.py deleted file mode 100644 index beea74a..0000000 --- a/tests/test_replay_buffer.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np -import random -from pytest import fixture - -from rl_sandbox.utils.replay_buffer import ReplayBuffer - -@fixture -def rep_buf(): - return ReplayBuffer() - -def test_creation(rep_buf): - assert len(rep_buf.states) == 0 - -def test_adding(rep_buf): - s = np.ones((3, 8)) - a = np.ones((3, 3)) - r = np.ones((3)) - rep_buf.add_rollout(s, a, r) - - assert len(rep_buf.states) == 3 - assert len(rep_buf.actions) == 3 - assert len(rep_buf.rewards) == 3 - - s = np.zeros((3, 8)) - a = np.zeros((3, 3)) - r = np.zeros((3)) - rep_buf.add_rollout(s, a, r) - - assert len(rep_buf.states) == 6 - assert len(rep_buf.actions) == 6 - assert len(rep_buf.rewards) == 6 - -def test_can_sample(rep_buf): - assert rep_buf.can_sample(1) == False - - s = np.ones((3, 8)) - a = np.ones((3, 3)) - r = np.ones((3)) - rep_buf.add_rollout(s, a, r) - - assert rep_buf.can_sample(5) == False - assert rep_buf.can_sample(1) == True - - rep_buf.add_rollout(s, a, r) - - assert rep_buf.can_sample(5) == True - -def test_sampling(rep_buf): - for i in range(1, 5): - rep_buf.add_rollout(np.ones((1,3)), np.ones((1,2)), i*np.ones((1))) - - random.seed(42) - _, _, r = rep_buf.sample(3) - assert (r == [3, 2, 4]).all()
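
To close, a minimal sketch (hypothetical sizes, plain tensors instead of the module's `nn.Parameter` codebook) of the two ideas inside `VQ_VAE.quantize` above: the expanded-squared-distance nearest-codebook lookup and the straight-through estimator that lets gradients reach the encoder output.

```python
import torch

# Hypothetical sizes: K codebook entries of dim D, BHW flattened encoder outputs.
K, D, BHW = 16, 8, 32
codebook = torch.randn(K, D)                # stands in for VQ_VAE.latent_space
z = torch.randn(BHW, D, requires_grad=True)

# Nearest codebook entry via ||a - b||^2 = a^2 - 2ab + b^2, as in quantize().
l2_dist = (torch.sum(z**2, dim=1, keepdim=True)
           - 2 * (z @ codebook.T)
           + torch.sum(codebook**2, dim=1))
assert torch.allclose(l2_dist, torch.cdist(z, codebook)**2, atol=1e-3)

ks = torch.argmin(l2_dist, dim=1)
e = torch.index_select(codebook, 0, ks)

# Straight-through form used in quantize(): the forward value equals the
# quantized vector e, while the gradient is passed through to z.
quantized = e + (z - z.detach())
assert torch.allclose(quantized, e)

quantized.sum().backward()
assert torch.allclose(z.grad, torch.ones_like(z))
```
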