From e4cfa1deb0fdffb13e4aa7a25e08be9f15d0a17f Mon Sep 17 00:00:00 2001 From: Midren Date: Mon, 17 Oct 2022 14:25:53 +0100 Subject: [PATCH 001/106] Added visualization for validation --- config/config.yaml | 4 ++- main.py | 34 ++++++++++++++-------- pyproject.toml | 2 ++ rl_sandbox/agents/dqn_agent.py | 4 +-- rl_sandbox/utils/replay_buffer.py | 39 +++++++++++++++++--------- rl_sandbox/utils/rollout_generation.py | 31 ++++++++++---------- 6 files changed, 69 insertions(+), 45 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 4f14d23..3a73f6f 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -5,8 +5,10 @@ env: CartPole-v1 seed: 42 training: - epochs: 5000 + epochs: 100 + steps_per_epoch: 1000 batch_size: 128 validation: rollout_num: 5 + visualize: true diff --git a/main.py b/main.py index 0936874..ad22cf5 100644 --- a/main.py +++ b/main.py @@ -4,6 +4,7 @@ from rl_sandbox.agents.dqn_agent import DqnAgent from rl_sandbox.utils.replay_buffer import ReplayBuffer from rl_sandbox.utils.rollout_generation import collect_rollout, fillup_replay_buffer, collect_rollout_num +from rl_sandbox.utils.visualization import Renderer from torch.utils.tensorboard.writer import SummaryWriter import numpy as np @@ -15,6 +16,7 @@ def main(cfg: DictConfig): print(OmegaConf.to_yaml(cfg)) env = gym.make(cfg.env) + visualized_env = gym.make(cfg.env, render_mode='rgb_array_list') buff = ReplayBuffer() # FIXME: samples should be also added afterwards @@ -31,18 +33,26 @@ def main(cfg: DictConfig): writer = SummaryWriter() for epoch_num in range(cfg.training.epochs): - # TODO: add exploration and adding data to buffer at each step - - s, a, r, n, f = buff.sample(cfg.training.batch_size) - - loss = agent.train(s, a, r, n, f) - writer.add_scalar('train/loss', loss, epoch_num) - - if epoch_num % 100 == 0: - rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) - average_len = np.mean(list(map(lambda x: len(x[0]), rollouts))) - writer.add_scalar('val/average_len', average_len, epoch_num) - + # TODO: add exploration annealing + for step in range(cfg.training.steps_per_epoch): + global_step = epoch_num * cfg.training.steps_per_epoch + step + # TODO: add exploration and adding data to buffer at each step + s, a, r, n, f = buff.sample(cfg.training.batch_size) + + loss = agent.train(s, a, r, n, f) + writer.add_scalar('train/loss', loss, global_step) + + ### Validation + rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) + average_len = np.mean(list(map(lambda x: len(x.states), rollouts))) + writer.add_scalar('val/average_len', average_len, epoch_num) + + if cfg.validation.visualize: + rollouts = collect_rollout_num(visualized_env, cfg.validation.visualized_rollout_num, agent, save_obs=True) + + for rollout in rollouts: + video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) + writer.add_video('val/visualization', video, epoch_num) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index e986927..e75a807 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,3 +14,5 @@ python = "^3.10" numpy = '*' nptyping = '*' gym = "^0.26.1" +pygame = '*' +moviepy = '*' diff --git a/rl_sandbox/agents/dqn_agent.py b/rl_sandbox/agents/dqn_agent.py index 13ba74d..96787b6 100644 --- a/rl_sandbox/agents/dqn_agent.py +++ b/rl_sandbox/agents/dqn_agent.py @@ -4,7 +4,7 @@ from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.fc_nn import fc_nn_generator from rl_sandbox.utils.replay_buffer import (Action, Actions, Rewards, State, - 
States, TerminationFlag) + States, TerminationFlags) class DqnAgent(RlAgent): @@ -24,7 +24,7 @@ def __init__(self, actions_num: int, def get_action(self, obs: State) -> Action: return np.array(torch.argmax(self.value_func(torch.from_numpy(obs)), dim=1)) - def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlag): + def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): # Bellman error: MSE( (r + gamma * max_a Q(S_t+1, a)) - Q(s_t, a) ) # check for is finished diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index ffdb8d4..e635f0d 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -1,20 +1,31 @@ import random import typing as t from collections import deque +from dataclasses import dataclass import numpy as np from nptyping import Bool, Int, Float, NDArray, Shape +Observation = NDArray[Shape["*,*,3"],Int] State = NDArray[Shape["*"],Float] Action = NDArray[Shape["*"],Int] +Observations = NDArray[Shape["*,*,*,3"],Int] States = NDArray[Shape["*,*"],Float] Actions = NDArray[Shape["*,*"],Int] Rewards = NDArray[Shape["*"],Float] -TerminationFlag = NDArray[Shape["*"],Bool] +TerminationFlags = NDArray[Shape["*"],Bool] + +@dataclass +class Rollout: + states: States + actions: Actions + rewards: Rewards + next_states: States + is_finished: TerminationFlags + observations: t.Optional[Observations] = None -# ReplayBuffer consists of next triplets: (s, a, r) class ReplayBuffer: def __init__(self, max_len=10_000): self.max_len = max_len @@ -23,24 +34,24 @@ def __init__(self, max_len=10_000): self.rewards: Rewards = np.array([]) self.next_states: States = np.array([]) - def add_rollout(self, s: States, a: Actions, r: Rewards, n: States, f: TerminationFlag): + def add_rollout(self, rollout: Rollout): if len(self.states) == 0: - self.states = s - self.actions = a - self.rewards = r - self.next_states = n - self.is_finished = f + self.states = rollout.states + self.actions = rollout.actions + self.rewards = rollout.rewards + self.next_states = rollout.next_states + self.is_finished = rollout.is_finished else: - self.states = np.concatenate([self.states, s]) - self.actions = np.concatenate([self.actions, a]) - self.rewards = np.concatenate([self.rewards, r]) - self.next_states = np.concatenate([self.next_states, n]) - self.is_finished = np.concatenate([self.is_finished, f]) + self.states = np.concatenate([self.states, rollout.states]) + self.actions = np.concatenate([self.actions, rollout.actions]) + self.rewards = np.concatenate([self.rewards, rollout.rewards]) + self.next_states = np.concatenate([self.next_states, rollout.next_states]) + self.is_finished = np.concatenate([self.is_finished, rollout.is_finished]) def can_sample(self, num: int): return len(self.states) >= num - def sample(self, num: int) -> t.Tuple[States, Actions, Rewards, States, TerminationFlag]: + def sample(self, num: int) -> t.Tuple[States, Actions, Rewards, States, TerminationFlags]: indeces = list(range(len(self.states))) random.shuffle(indeces) indeces = indeces[:num] diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index bdbd503..c647285 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -3,39 +3,38 @@ import gym import numpy as np -from rl_sandbox.utils.replay_buffer import (Actions, ReplayBuffer, Rewards, - States, TerminationFlag) +from rl_sandbox.utils.replay_buffer import (ReplayBuffer, 
Rollout) - -def collect_rollout(env: gym.Env, agent: t.Optional[t.Any] = None) -> t.Tuple[States, Actions, Rewards, States, TerminationFlag]: +def collect_rollout(env: gym.Env, agent: t.Optional[t.Any] = None, save_obs: bool = False) -> Rollout: s, a, r, n, f = [], [], [], [], [] - obs, _ = env.reset() + state, _ = env.reset() terminated = False while not terminated: if agent is None: action = env.action_space.sample() else: - # FIXME: you know - action = agent.get_action(obs.reshape(1, -1))[0] - new_obs, reward, terminated, _, _ = env.step(action) - s.append(obs) + # FIXME: move reshaping inside DqnAgent + action = agent.get_action(state.reshape(1, -1))[0] + new_state, reward, terminated, _, _ = env.step(action) + s.append(state) a.append(action) r.append(reward) - n.append(new_obs) + n.append(new_state) f.append(terminated) - obs = new_obs - return np.array(s), np.array(a).reshape(len(s), -1), np.array(r, dtype=np.float32), np.array(n), np.array(f) + state = new_state + + obs = np.stack(list(env.render())) if save_obs else None + return Rollout(np.array(s), np.array(a).reshape(len(s), -1), np.array(r, dtype=np.float32), np.array(n), np.array(f), obs) -def collect_rollout_num(env: gym.Env, num: int, agent: t.Optional[t.Any] = None) -> t.List[t.Tuple[States, Actions, Rewards, States, TerminationFlag]]: +def collect_rollout_num(env: gym.Env, num: int, agent: t.Optional[t.Any] = None, save_obs: bool = False) -> t.List[Rollout]: rollouts = [] for _ in range(num): - rollouts.append(collect_rollout(env, agent)) + rollouts.append(collect_rollout(env, agent, save_obs)) return rollouts def fillup_replay_buffer(env: gym.Env, rep_buffer: ReplayBuffer, num: int): while not rep_buffer.can_sample(num): - s, a, r, n, f = collect_rollout(env) - rep_buffer.add_rollout(s, a, r, n, f) + rep_buffer.add_rollout(collect_rollout(env)) From b2d2769b8e7f06380cb311f6cfdb5534a0a16f5f Mon Sep 17 00:00:00 2001 From: Midren Date: Mon, 17 Oct 2022 18:36:41 +0100 Subject: [PATCH 002/106] working dqn --- config/agent/dqn_agent.yaml | 6 ++-- config/config.yaml | 4 +-- main.py | 40 +++++++++++++++++--------- rl_sandbox/agents/dqn_agent.py | 25 +++++++++------- rl_sandbox/utils/replay_buffer.py | 13 ++++++++- rl_sandbox/utils/rollout_generation.py | 6 ++-- rl_sandbox/utils/schedulers.py | 21 ++++++++++++++ tests/test_linear_scheduler.py | 15 ++++++++++ 8 files changed, 97 insertions(+), 33 deletions(-) create mode 100644 rl_sandbox/utils/schedulers.py create mode 100644 tests/test_linear_scheduler.py diff --git a/config/agent/dqn_agent.yaml b/config/agent/dqn_agent.yaml index 327bf96..f8097cd 100644 --- a/config/agent/dqn_agent.yaml +++ b/config/agent/dqn_agent.yaml @@ -1,4 +1,4 @@ name: dqn -hidden_layer_size: 32 -num_layers: 2 -discount_factor: 0.99 +hidden_layer_size: 16 +num_layers: 1 +discount_factor: 0.98 diff --git a/config/config.yaml b/config/config.yaml index 3a73f6f..c3fc383 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -3,10 +3,10 @@ defaults: env: CartPole-v1 seed: 42 +device_type: cpu training: - epochs: 100 - steps_per_epoch: 1000 + epochs: 5000 batch_size: 128 validation: diff --git a/main.py b/main.py index ad22cf5..570b856 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ from rl_sandbox.agents.dqn_agent import DqnAgent from rl_sandbox.utils.replay_buffer import ReplayBuffer from rl_sandbox.utils.rollout_generation import collect_rollout, fillup_replay_buffer, collect_rollout_num -from rl_sandbox.utils.visualization import Renderer +from rl_sandbox.utils.schedulers import LinearScheduler 
from torch.utils.tensorboard.writer import SummaryWriter import numpy as np @@ -19,7 +19,6 @@ def main(cfg: DictConfig): visualized_env = gym.make(cfg.env, render_mode='rgb_array_list') buff = ReplayBuffer() - # FIXME: samples should be also added afterwards fillup_replay_buffer(env, buff, cfg.training.batch_size) # INFO: currently supports only discrete action space @@ -27,32 +26,45 @@ def main(cfg: DictConfig): agent_name = agent_params.pop('name') agent = DqnAgent(obs_space_num=env.observation_space.shape[0], actions_num=env.action_space.n, + device_type=cfg.device_type, **agent_params, ) writer = SummaryWriter() + scheduler = LinearScheduler(0.9, 0.01, 5_000) + + global_step = 0 for epoch_num in range(cfg.training.epochs): - # TODO: add exploration annealing - for step in range(cfg.training.steps_per_epoch): - global_step = epoch_num * cfg.training.steps_per_epoch + step - # TODO: add exploration and adding data to buffer at each step + ### Training and exploration + state, _ = env.reset() + terminated = False + while not terminated: + if np.random.random() > scheduler.step(): + action = env.action_space.sample() + else: + action = agent.get_action(state) + new_state, reward, terminated, _, _ = env.step(action) + buff.add_sample(state, action, reward, new_state, terminated) + s, a, r, n, f = buff.sample(cfg.training.batch_size) loss = agent.train(s, a, r, n, f) writer.add_scalar('train/loss', loss, global_step) + global_step += 1 ### Validation - rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) - average_len = np.mean(list(map(lambda x: len(x.states), rollouts))) - writer.add_scalar('val/average_len', average_len, epoch_num) + if epoch_num % 100 == 0: + rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) + average_len = np.mean(list(map(lambda x: len(x.states), rollouts))) + writer.add_scalar('val/average_len', average_len, epoch_num) - if cfg.validation.visualize: - rollouts = collect_rollout_num(visualized_env, cfg.validation.visualized_rollout_num, agent, save_obs=True) + if cfg.validation.visualize: + rollouts = collect_rollout_num(visualized_env, 1, agent, save_obs=True) - for rollout in rollouts: - video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) - writer.add_video('val/visualization', video, epoch_num) + for rollout in rollouts: + video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) + writer.add_video('val/visualization', video, epoch_num) if __name__ == "__main__": diff --git a/rl_sandbox/agents/dqn_agent.py b/rl_sandbox/agents/dqn_agent.py index 96787b6..9ce2887 100644 --- a/rl_sandbox/agents/dqn_agent.py +++ b/rl_sandbox/agents/dqn_agent.py @@ -12,36 +12,39 @@ def __init__(self, actions_num: int, obs_space_num: int, hidden_layer_size: int, num_layers: int, - discount_factor: float): + discount_factor: float, + device_type: str = 'cpu'): self.gamma = discount_factor self.value_func = fc_nn_generator(obs_space_num, actions_num, hidden_layer_size, - num_layers) + num_layers).to(device_type) self.optimizer = torch.optim.Adam(self.value_func.parameters(), lr=1e-3) self.loss = torch.nn.MSELoss() def get_action(self, obs: State) -> Action: - return np.array(torch.argmax(self.value_func(torch.from_numpy(obs)), dim=1)) + return np.array(torch.argmax(self.value_func(torch.from_numpy(obs.reshape(1, -1)).to(device_type)), dim=1).detach().cpu())[0] def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): # Bellman error: MSE( (r + gamma * max_a Q(S_t+1, a)) - Q(s_t, a) ) 
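Not part of this patch: the bootstrap value computed below is taken from the same online network that is being optimised, while the `# TODO: double dqn with target network` added further down in this hunk points at the usual stabilisation, where the target comes from a periodically synchronised copy. A minimal sketch of that target computation, assuming batch tensors shaped as in this `train` method; `target_value_func` and `bellman_target` are illustrative names, not code from this repository.

import torch

# Sketch only: r, next_s, is_finished are batch tensors as in train() above,
# gamma is the discount factor, and target_value_func is an assumed frozen copy
# of value_func, refreshed every N steps via
# target_value_func.load_state_dict(value_func.state_dict()).
def bellman_target(target_value_func, r, next_s, is_finished, gamma):
    with torch.no_grad():
        next_q = target_value_func(next_s)         # B x actions_num
        max_next_q = next_q.max(dim=1).values      # max_a Q_target(s_{t+1}, a)
        # terminal transitions contribute only the immediate reward
        return r + gamma * max_next_q * torch.logical_not(is_finished)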
# check for is finished - s = torch.from_numpy(s) - a = torch.from_numpy(a) - r = torch.from_numpy(r) - next = torch.from_numpy(next) - is_finished = torch.from_numpy(is_finished) + s = torch.from_numpy(s).to(device_type) + a = torch.from_numpy(a).to(device_type) + r = torch.from_numpy(r).to(device_type) + next = torch.from_numpy(next).to(device_type) + is_finished = torch.from_numpy(is_finished).to(device_type) + # TODO: normalize input + # TODO: double dqn with target network values = self.value_func(next) indeces = torch.argmax(values, dim=1) - x = r + (self.gamma * torch.gather(values, dim=1, index=indeces.unsqueeze(1)).squeeze(1)) * torch.logical_not(is_finished) + target = r + (self.gamma * torch.gather(values, dim=1, index=indeces.unsqueeze(1)).squeeze(1)) * torch.logical_not(is_finished) - loss = self.loss(x, torch.gather(self.value_func(s), dim=1, index=a).squeeze(1)) + loss = self.loss(torch.gather(self.value_func(s), dim=1, index=a).squeeze(1), target.detach()) self.optimizer.zero_grad() loss.backward() self.optimizer.step() - return loss.detach() + return loss.detach().cpu() diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index e635f0d..8821301 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -27,7 +27,7 @@ class Rollout: class ReplayBuffer: - def __init__(self, max_len=10_000): + def __init__(self, max_len=2_000): self.max_len = max_len self.states: States = np.array([]) self.actions: Actions = np.array([]) @@ -48,6 +48,17 @@ def add_rollout(self, rollout: Rollout): self.next_states = np.concatenate([self.next_states, rollout.next_states]) self.is_finished = np.concatenate([self.is_finished, rollout.is_finished]) + if len(self.states) >= self.max_len: + self.states = self.states + self.actions = self.actions + self.rewards = self.rewards + self.next_states = self.next_states + self.is_finished = self.is_finished + + def add_sample(self, s: State, a: Action, r: float, n: State, f: bool): + rollout = Rollout(np.array([s]), np.expand_dims(np.array([a]), 0), np.array([r], dtype=np.float32), np.array([n]), np.array([f])) + self.add_rollout(rollout) + def can_sample(self, num: int): return len(self.states) >= num diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index c647285..c6f65a6 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -12,11 +12,11 @@ def collect_rollout(env: gym.Env, agent: t.Optional[t.Any] = None, save_obs: boo terminated = False while not terminated: + # TODO: use gym.ActionWrapper instead if agent is None: action = env.action_space.sample() else: - # FIXME: move reshaping inside DqnAgent - action = agent.get_action(state.reshape(1, -1))[0] + action = agent.get_action(state) new_state, reward, terminated, _, _ = env.step(action) s.append(state) a.append(action) @@ -29,6 +29,7 @@ def collect_rollout(env: gym.Env, agent: t.Optional[t.Any] = None, save_obs: boo return Rollout(np.array(s), np.array(a).reshape(len(s), -1), np.array(r, dtype=np.float32), np.array(n), np.array(f), obs) def collect_rollout_num(env: gym.Env, num: int, agent: t.Optional[t.Any] = None, save_obs: bool = False) -> t.List[Rollout]: + # TODO: paralelyze rollouts = [] for _ in range(num): rollouts.append(collect_rollout(env, agent, save_obs)) @@ -36,5 +37,6 @@ def collect_rollout_num(env: gym.Env, num: int, agent: t.Optional[t.Any] = None, def fillup_replay_buffer(env: gym.Env, rep_buffer: ReplayBuffer, num: int): + # TODO: 
paralelyze while not rep_buffer.can_sample(num): rep_buffer.add_rollout(collect_rollout(env)) diff --git a/rl_sandbox/utils/schedulers.py b/rl_sandbox/utils/schedulers.py new file mode 100644 index 0000000..d49adf2 --- /dev/null +++ b/rl_sandbox/utils/schedulers.py @@ -0,0 +1,21 @@ +from abc import ABCMeta + +import numpy as np + +class Scheduler(metaclass=ABCMeta): + def step(self) -> float: + ... + +class LinearScheduler(Scheduler): + def __init__(self, initial_value, final_value, duration): + self._init = initial_value + self._final = final_value + self._dur = duration - 1 + self._curr_t = 0 + + def step(self) -> float: + if self._curr_t >= self._dur: + return self._final + val = np.interp([self._curr_t], [0, self._dur], [self._init, self._final]) + self._curr_t += 1 + return val diff --git a/tests/test_linear_scheduler.py b/tests/test_linear_scheduler.py new file mode 100644 index 0000000..e39a66d --- /dev/null +++ b/tests/test_linear_scheduler.py @@ -0,0 +1,15 @@ +from rl_sandbox.utils.schedulers import LinearScheduler + +def test_linear_schedule(): + s = LinearScheduler(0, 10, 5) + assert s.step() == 0 + assert s.step() == 2.5 + assert s.step() == 5 + assert s.step() == 7.5 + assert s.step() == 10.0 + +def test_linear_schedule_after(): + s = LinearScheduler(0, 10, 5) + for _ in range(5): + s.step() + assert s.step() == 10.0 From 3a358e64c0c9723a5e56f67355c3a23097730349 Mon Sep 17 00:00:00 2001 From: Midren Date: Wed, 26 Oct 2022 23:30:19 +0100 Subject: [PATCH 003/106] Added VAE implementation --- rl_sandbox/vision/vae.py | 137 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 rl_sandbox/vision/vae.py diff --git a/rl_sandbox/vision/vae.py b/rl_sandbox/vision/vae.py new file mode 100644 index 0000000..3dc00f2 --- /dev/null +++ b/rl_sandbox/vision/vae.py @@ -0,0 +1,137 @@ +from collections import defaultdict +from pathlib import Path + +import numpy as np +import torch +import torchvision +from PIL.Image import Image +from torch import nn +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + + +class VAE(nn.Module): + def __init__(self, latent_dim=2, kl_weight=2.5e-4): + super().__init__() + self.latent_dim = latent_dim + self.kl_weight = kl_weight + + self.encoder = nn.Sequential( + nn.BatchNorm2d(1), + nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1), # 1x28x28 -> 8x13x13 + nn.LeakyReLU(inplace=True), + nn.BatchNorm2d(8), + + nn.Conv2d(8, 32, kernel_size=3, stride=2), # 8x13x13 -> 32x6x6 + nn.LeakyReLU(inplace=True), + nn.BatchNorm2d(32), + + nn.Conv2d(32, 8, kernel_size=1), # 32x6x6 -> 8x6x6 + nn.LeakyReLU(inplace=True), + nn.BatchNorm2d(8), + + nn.Flatten(), # 8x6x6 -> 36*8 + ) + + self.f_mu = nn.Linear(288, self.latent_dim) + self.f_log_sigma = nn.Linear(288, self.latent_dim) + + self.decoder_1 = nn.Sequential( + nn.Linear(self.latent_dim, 288), + nn.LeakyReLU(inplace=True), + ) + + self.decoder_2 = nn.Sequential( + nn.BatchNorm2d(8), + nn.ConvTranspose2d(8, 32, kernel_size=1), + nn.LeakyReLU(inplace=True), + + nn.BatchNorm2d(32), + nn.ConvTranspose2d(32, 8, kernel_size=3, stride=2), + nn.LeakyReLU(inplace=True), + + nn.BatchNorm2d(8), + nn.ConvTranspose2d(8, 1, kernel_size=3, stride=2, output_padding=1), + nn.LeakyReLU(inplace=True), + ) + + def forward(self, X): + z_h = self.encoder(X) + + z_mu = self.f_mu(z_h) + z_log_sigma = self.f_log_sigma(z_h) + + z = z_mu + z_log_sigma.exp()*torch.rand_like(z_mu).to('mps') + + x_h_1 = self.decoder_1(z) + x_h = self.decoder_2(x_h_1.view(-1, 8, 6, 6)) + return x_h, z_mu, 
z_log_sigma + + def calculate_loss(self, x, x_h, z_mu, z_log_sigma) -> dict[str, torch.Tensor]: + # loss = log p(x | z) + KL(q(z) || p(z)) + # p(z) = N(0, 1) + L_rec = torch.nn.MSELoss() + + loss_kl = -1 * torch.mean(torch.sum(z_log_sigma + 0.5*(1 - z_log_sigma.exp()**2 - z_mu**2), dim=1), dim=0) + loss_rec = L_rec(x, x_h) + + return {'loss': loss_rec + self.kl_weight * loss_kl, 'loss_rec': loss_rec, 'loss_kl': loss_kl} + +def image_preprocessing(img: Image): + return torchvision.transforms.ToTensor()(img) + +if __name__ == "__main__": + train_mnist_data = torchvision.datasets.MNIST(str(Path()/'data'/'mnist'), + download=True, + train=True, + transform=image_preprocessing) + test_mnist_data = torchvision.datasets.MNIST(str(Path()/'data'/'mnist'), + download=True, + train=False, + transform=image_preprocessing) + train_data_loader = torch.utils.data.DataLoader(train_mnist_data, + batch_size=128, + shuffle=True, + num_workers=8) + test_data_loader = torch.utils.data.DataLoader(test_mnist_data, + batch_size=128, + shuffle=True, + num_workers=8) + import socket + from datetime import datetime + current_time = datetime.now().strftime("%b%d_%H-%M-%S") + logger = SummaryWriter(log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}") + + device = 'mps' + model = VAE().to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + for epoch in tqdm(range(100)): + + logger.add_scalar('epoch', epoch, epoch) + + for sample_num, (img, target) in enumerate(train_data_loader): + recovered_img, z_mu, z_log_sigma = model(img.to(device)) + + losses = model.calculate_loss(img.to(device), recovered_img, z_mu, z_log_sigma) + loss = losses['loss'] + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + for loss_kind in losses: + logger.add_scalar(f'train/{loss_kind}', losses[loss_kind].cpu().detach(), epoch*len(train_data_loader)+sample_num) + + val_losses = defaultdict(list) + for img, target in test_data_loader: + recovered_img, z_mu, z_log_sigma = model(img.to(device)) + losses = model.calculate_loss(img.to(device), recovered_img, z_mu, z_log_sigma) + + for loss_kind in losses: + val_losses[loss_kind].append(losses[loss_kind].cpu().detach()) + + for loss_kind in val_losses: + logger.add_scalar(f'val/{loss_kind}', np.mean(val_losses[loss_kind]), epoch) + logger.add_image(f'val/example_image', img.cpu().detach()[0], epoch) + logger.add_image(f'val/example_image_rec', recovered_img.cpu().detach()[0], epoch) From aeb052da4eda2e3ed74b331134217efa5fc8ec28 Mon Sep 17 00:00:00 2001 From: Midren Date: Wed, 2 Nov 2022 17:53:28 +0000 Subject: [PATCH 004/106] Added dm_control integration --- config/agent/dqn_agent.yaml | 2 +- config/config.yaml | 13 +++-- main.py | 72 ++++++++++++++++++++------ pyproject.toml | 4 ++ rl_sandbox/agents/dqn_agent.py | 13 ++--- rl_sandbox/agents/random_agent.py | 30 +++++++++++ rl_sandbox/metrics.py | 17 ++++++ rl_sandbox/utils/dm_control.py | 50 ++++++++++++++++++ rl_sandbox/utils/replay_buffer.py | 7 +-- rl_sandbox/utils/rollout_generation.py | 47 ++++++++++++----- 10 files changed, 211 insertions(+), 44 deletions(-) create mode 100644 rl_sandbox/agents/random_agent.py create mode 100644 rl_sandbox/metrics.py create mode 100644 rl_sandbox/utils/dm_control.py diff --git a/config/agent/dqn_agent.yaml b/config/agent/dqn_agent.yaml index f8097cd..b95a589 100644 --- a/config/agent/dqn_agent.yaml +++ b/config/agent/dqn_agent.yaml @@ -1,4 +1,4 @@ name: dqn hidden_layer_size: 16 num_layers: 1 -discount_factor: 0.98 +discount_factor: 0.999 diff --git a/config/config.yaml 
b/config/config.yaml index c3fc383..76962a7 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,14 +1,21 @@ defaults: - agent/dqn_agent + - _self_ + +env: + type: dm_control + domain_name: cartpole + task_name: swingup + # type: gym + # name: CartPole-v1 -env: CartPole-v1 seed: 42 -device_type: cpu +device_type: mps training: epochs: 5000 batch_size: 128 - + validation: rollout_num: 5 visualize: true diff --git a/main.py b/main.py index 570b856..29c4b38 100644 --- a/main.py +++ b/main.py @@ -1,31 +1,55 @@ +import gym import hydra +import numpy as np +from dm_control import suite from omegaconf import DictConfig, OmegaConf +from torch.utils.tensorboard.writer import SummaryWriter +from tqdm import tqdm from rl_sandbox.agents.dqn_agent import DqnAgent +from rl_sandbox.agents.random_agent import RandomAgent +from rl_sandbox.metrics import MetricsEvaluator +from rl_sandbox.utils.dm_control import ActionDiscritizer, decode_dm_ts from rl_sandbox.utils.replay_buffer import ReplayBuffer -from rl_sandbox.utils.rollout_generation import collect_rollout, fillup_replay_buffer, collect_rollout_num +from rl_sandbox.utils.rollout_generation import (collect_rollout, + collect_rollout_num, + fillup_replay_buffer) from rl_sandbox.utils.schedulers import LinearScheduler -from torch.utils.tensorboard.writer import SummaryWriter -import numpy as np - -import gym @hydra.main(version_base="1.2", config_path='config', config_name='config') def main(cfg: DictConfig): print(OmegaConf.to_yaml(cfg)) - env = gym.make(cfg.env) - visualized_env = gym.make(cfg.env, render_mode='rgb_array_list') + match cfg.env.type: + case "dm_control": + env = suite.load(domain_name=cfg.env.domain_name, + task_name=cfg.env.task_name) + visualized_env = env + case "gym": + env = gym.make(cfg.env) + visualized_env = gym.make(cfg.env, render_mode='rgb_array_list') + case _: + raise RuntimeError("Invalid environment type") buff = ReplayBuffer() fillup_replay_buffer(env, buff, cfg.training.batch_size) - # INFO: currently supports only discrete action space agent_params = {**cfg.agent} agent_name = agent_params.pop('name') - agent = DqnAgent(obs_space_num=env.observation_space.shape[0], - actions_num=env.action_space.n, + action_disritizer = ActionDiscritizer(env.action_spec(), values_per_dim=10) + metrics_evaluator = MetricsEvaluator() + + match cfg.env.type: + case "dm_control": + obs_space_num = sum([v.shape[0] for v in env.observation_spec().values()]) + case "gym": + obs_space_num = env.observation_space.shape[0] + + exploration_agent = RandomAgent(env) + agent = DqnAgent(obs_space_num=obs_space_num, + actions_num=action_disritizer.shape, + # actions_num=env.action_space.n, device_type=cfg.device_type, **agent_params, ) @@ -35,19 +59,34 @@ def main(cfg: DictConfig): scheduler = LinearScheduler(0.9, 0.01, 5_000) global_step = 0 - for epoch_num in range(cfg.training.epochs): + for epoch_num in tqdm(range(cfg.training.epochs)): ### Training and exploration - state, _ = env.reset() + + match cfg.env.type: + case "dm_control": + state, _, _ = decode_dm_ts(env.reset()) + case "gym": + state, _ = env.reset() + terminated = False while not terminated: if np.random.random() > scheduler.step(): - action = env.action_space.sample() + action = exploration_agent.get_action(state) + action = action_disritizer.discretize(action) else: action = agent.get_action(state) - new_state, reward, terminated, _, _ = env.step(action) + + match cfg.env.type: + case "dm_control": + new_state, reward, terminated = 
decode_dm_ts(env.step(action_disritizer.undiscretize(action))) + case "gym": + new_state, reward, terminated, _, _ = env.step(action) + action = action_disritizer.undiscretize(action) + buff.add_sample(state, action, reward, new_state, terminated) s, a, r, n, f = buff.sample(cfg.training.batch_size) + a = np.stack([action_disritizer.discretize(a_) for a_ in a]).reshape(-1, 1) loss = agent.train(s, a, r, n, f) writer.add_scalar('train/loss', loss, global_step) @@ -56,8 +95,9 @@ def main(cfg: DictConfig): ### Validation if epoch_num % 100 == 0: rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) - average_len = np.mean(list(map(lambda x: len(x.states), rollouts))) - writer.add_scalar('val/average_len', average_len, epoch_num) + metrics = metrics_evaluator.calculate_metrics(rollouts) + for metric_name, metric in metrics.items(): + writer.add_scalar(f'val/{metric_name}', metric, epoch_num) if cfg.validation.visualize: rollouts = collect_rollout_num(visualized_env, 1, agent, save_obs=True) diff --git a/pyproject.toml b/pyproject.toml index e75a807..8e6f95e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,3 +16,7 @@ nptyping = '*' gym = "^0.26.1" pygame = '*' moviepy = '*' +torchvision = '^0.13' +torch = '^1.12' +tensorboard = '^2.0' +dm-control = '^1.0.0' diff --git a/rl_sandbox/agents/dqn_agent.py b/rl_sandbox/agents/dqn_agent.py index 9ce2887..5ca9368 100644 --- a/rl_sandbox/agents/dqn_agent.py +++ b/rl_sandbox/agents/dqn_agent.py @@ -21,19 +21,20 @@ def __init__(self, actions_num: int, num_layers).to(device_type) self.optimizer = torch.optim.Adam(self.value_func.parameters(), lr=1e-3) self.loss = torch.nn.MSELoss() + self.device_type = device_type def get_action(self, obs: State) -> Action: - return np.array(torch.argmax(self.value_func(torch.from_numpy(obs.reshape(1, -1)).to(device_type)), dim=1).detach().cpu())[0] + return np.array(torch.argmax(self.value_func(torch.from_numpy(obs.reshape(1, -1)).to(self.device_type)), dim=1).detach().cpu())[0] def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): # Bellman error: MSE( (r + gamma * max_a Q(S_t+1, a)) - Q(s_t, a) ) # check for is finished - s = torch.from_numpy(s).to(device_type) - a = torch.from_numpy(a).to(device_type) - r = torch.from_numpy(r).to(device_type) - next = torch.from_numpy(next).to(device_type) - is_finished = torch.from_numpy(is_finished).to(device_type) + s = torch.from_numpy(s).to(self.device_type) + a = torch.from_numpy(a).to(self.device_type) + r = torch.from_numpy(r).to(self.device_type) + next = torch.from_numpy(next).to(self.device_type) + is_finished = torch.from_numpy(is_finished).to(self.device_type) # TODO: normalize input # TODO: double dqn with target network diff --git a/rl_sandbox/agents/random_agent.py b/rl_sandbox/agents/random_agent.py new file mode 100644 index 0000000..d06e031 --- /dev/null +++ b/rl_sandbox/agents/random_agent.py @@ -0,0 +1,30 @@ +import gym +import numpy as np +from dm_control.composer.environment import Environment as dmEnv +from nptyping import Float, NDArray, Shape + +from rl_sandbox.agents.rl_agent import RlAgent +from rl_sandbox.utils.dm_control import ActionDiscritizer +from rl_sandbox.utils.replay_buffer import (Action, Actions, Rewards, State, + States, TerminationFlags) + + +class RandomAgent(RlAgent): + def __init__(self, env: gym.Env | dmEnv): + self.action_space = None + self.action_spec = None + if isinstance(env, gym.Env): + self.action_space = env.action_space + else: + self.action_spec = 
env.action_spec() + + def get_action(self, obs: State) -> Action | NDArray[Shape["*"],Float]: + if self.action_space is not None: + return self.action_space.sample() + else: + return np.random.uniform(self.action_spec.minimum, + self.action_spec.maximum, + size=self.action_spec.shape) + + def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): + pass diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py new file mode 100644 index 0000000..07a20a9 --- /dev/null +++ b/rl_sandbox/metrics.py @@ -0,0 +1,17 @@ +import numpy as np + +from rl_sandbox.utils.replay_buffer import Rollout + + +class MetricsEvaluator(): + def calculate_metrics(self, rollouts: list[Rollout]): + return { + 'episode_len': self._episode_duration(rollouts), + 'episode_return': self._episode_return(rollouts) + } + + def _episode_duration(self, rollouts: list[Rollout]): + return np.mean(list(map(lambda x: len(x.states), rollouts))) + + def _episode_return(self, rollouts: list[Rollout]): + return np.mean(list(map(lambda x: sum(x.rewards), rollouts))) diff --git a/rl_sandbox/utils/dm_control.py b/rl_sandbox/utils/dm_control.py new file mode 100644 index 0000000..f31e64d --- /dev/null +++ b/rl_sandbox/utils/dm_control.py @@ -0,0 +1,50 @@ +import numpy as np +from dm_env import specs +from nptyping import Float, Int, NDArray, Shape + + +# TODO: add tests +class ActionDiscritizer: + def __init__(self, action_spec: specs.BoundedArray, values_per_dim: int): + self.actions_dim = action_spec.shape[0] + self.min = action_spec.minimum + self.max = action_spec.maximum + self.per_dim = values_per_dim + self.shape = self.per_dim**self.actions_dim + + # actions_dim X per_dim + self.grid = np.stack([np.linspace(min, max, self.per_dim, endpoint=True) for min, max in zip(self.min, self.max)]) + + def discretize(self, action: NDArray[Shape['*'], Float]) -> NDArray[Shape['*'], Int]: + ks = np.argmin((self.grid - np.ones((self.per_dim, 1)).dot(action).T)**2, axis=1) + a = 0 + for i, k in enumerate(ks): + a += k*self.per_dim**i + # ret_a = np.zeros(self.shape, dtype=np.int64) + # ret_a[a] = 1 + # return ret_a + return a + + def undiscretize(self, action: NDArray[Shape['*'], Int]) -> NDArray[Shape['*'], Float]: + ks = [] + # k = np.argmax(action) + k = action + for i in range(self.per_dim - 1, -1, -1): + ks.append(k // self.per_dim**i) + k -= ks[-1] * self.per_dim**i + + a = [] + for k, vals in zip(reversed(ks), self.grid): + a.append(vals[k]) + return np.array(a) + +def decode_dm_ts(time_step): + state = time_step.observation + state = np.concatenate([state[s] for s in state], dtype=np.float32) + reward = time_step.reward + terminated = time_step.last() + # if time_step.discount is not None: + # terminated = not time_step.discount + # else: + # terminated = False + return state, reward, terminated diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index 8821301..3835b0b 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -1,10 +1,9 @@ -import random import typing as t from collections import deque from dataclasses import dataclass import numpy as np -from nptyping import Bool, Int, Float, NDArray, Shape +from nptyping import Bool, Float, Int, NDArray, Shape Observation = NDArray[Shape["*,*,3"],Int] State = NDArray[Shape["*"],Float] @@ -63,7 +62,5 @@ def can_sample(self, num: int): return len(self.states) >= num def sample(self, num: int) -> t.Tuple[States, Actions, Rewards, States, TerminationFlags]: - indeces = 
list(range(len(self.states))) - random.shuffle(indeces) - indeces = indeces[:num] + indeces = np.random.choice(len(self.states), num) return self.states[indeces], self.actions[indeces], self.rewards[indeces], self.next_states[indeces], self.is_finished[indeces] diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index c6f65a6..8f19064 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -2,30 +2,51 @@ import gym import numpy as np +from dm_control.composer.environment import Environment as dmEnv -from rl_sandbox.utils.replay_buffer import (ReplayBuffer, Rollout) +from rl_sandbox.agents.random_agent import RandomAgent +from rl_sandbox.utils.dm_control import ActionDiscritizer, decode_dm_ts +from rl_sandbox.utils.replay_buffer import ReplayBuffer, Rollout -def collect_rollout(env: gym.Env, agent: t.Optional[t.Any] = None, save_obs: bool = False) -> Rollout: - s, a, r, n, f = [], [], [], [], [] - state, _ = env.reset() - terminated = False +def collect_rollout(env: gym.Env | dmEnv, agent: t.Optional[t.Any] = None, save_obs: bool = False) -> Rollout: + s, a, r, n, f, o = [], [], [], [], [], [] + + match env: + case gym.Env(): + state, _ = env.reset() + case dmEnv(): + state, _, terminated = decode_dm_ts(env.reset()) + + if agent is None: + agent = RandomAgent(env) while not terminated: - # TODO: use gym.ActionWrapper instead - if agent is None: - action = env.action_space.sample() - else: - action = agent.get_action(state) - new_state, reward, terminated, _, _ = env.step(action) + action = agent.get_action(state) + + match env: + case gym.Env(): + new_state, reward, terminated, _, _ = env.step(action) + case dmEnv(): + new_state, reward, terminated = decode_dm_ts(env.step(action)) + s.append(state) - a.append(action) + # FIXME: action discritezer should be defined once + action_disritizer = ActionDiscritizer(env.action_spec(), values_per_dim=10) + a.append(action_disritizer.discretize(action)) r.append(reward) n.append(new_state) f.append(terminated) + + if save_obs and isinstance(env, dmEnv): + o.append(env.physics.render(128, 128, camera_id=0)) state = new_state - obs = np.stack(list(env.render())) if save_obs else None + match env: + case gym.Env(): + obs = np.stack(list(env.render())) if save_obs else None + case dmEnv(): + obs = np.array(o) if save_obs else None return Rollout(np.array(s), np.array(a).reshape(len(s), -1), np.array(r, dtype=np.float32), np.array(n), np.array(f), obs) def collect_rollout_num(env: gym.Env, num: int, agent: t.Optional[t.Any] = None, save_obs: bool = False) -> t.List[Rollout]: From 21c179ff283bd8dcc7de828f4d9ffcf689de2494 Mon Sep 17 00:00:00 2001 From: Midren Date: Fri, 4 Nov 2022 16:23:24 +0000 Subject: [PATCH 005/106] Added VQ-VAE, tested on CFAR10 --- rl_sandbox/vision/vae.py | 160 +++++++++++++++++++++------------- rl_sandbox/vision/vq_vae.py | 167 ++++++++++++++++++++++++++++++++++++ 2 files changed, 267 insertions(+), 60 deletions(-) create mode 100644 rl_sandbox/vision/vq_vae.py diff --git a/rl_sandbox/vision/vae.py b/rl_sandbox/vision/vae.py index 3dc00f2..8a8d133 100644 --- a/rl_sandbox/vision/vae.py +++ b/rl_sandbox/vision/vae.py @@ -10,50 +10,72 @@ from tqdm import tqdm -class VAE(nn.Module): - def __init__(self, latent_dim=2, kl_weight=2.5e-4): +class ResBlock(nn.Module): + + def __init__(self, in_channels, hidden_units=256): super().__init__() - self.latent_dim = latent_dim - self.kl_weight = kl_weight - self.encoder = nn.Sequential( - nn.BatchNorm2d(1), - 
nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1), # 1x28x28 -> 8x13x13 - nn.LeakyReLU(inplace=True), - nn.BatchNorm2d(8), + self.block = nn.Sequential( + nn.ReLU(), nn.Conv2d(in_channels, hidden_units, kernel_size=3, + padding='same'), nn.ReLU(inplace=True), + nn.Conv2d(hidden_units, in_channels, kernel_size=1, padding='same')) - nn.Conv2d(8, 32, kernel_size=3, stride=2), # 8x13x13 -> 32x6x6 - nn.LeakyReLU(inplace=True), - nn.BatchNorm2d(32), + def forward(self, X): + output = self.block(X) + return X + output - nn.Conv2d(32, 8, kernel_size=1), # 32x6x6 -> 8x6x6 - nn.LeakyReLU(inplace=True), - nn.BatchNorm2d(8), - nn.Flatten(), # 8x6x6 -> 36*8 - ) +class VAE(nn.Module): - self.f_mu = nn.Linear(288, self.latent_dim) - self.f_log_sigma = nn.Linear(288, self.latent_dim) + def __init__(self, latent_dim=3, kl_weight=2.5e-4): + super().__init__() + self.latent_dim = latent_dim + self.kl_weight = kl_weight - self.decoder_1 = nn.Sequential( - nn.Linear(self.latent_dim, 288), - nn.LeakyReLU(inplace=True), - ) + in_channels = 3 + out_channels = 128 - self.decoder_2 = nn.Sequential( - nn.BatchNorm2d(8), - nn.ConvTranspose2d(8, 32, kernel_size=1), - nn.LeakyReLU(inplace=True), + self.encoder = nn.Sequential( + nn.BatchNorm2d(3), + nn.Conv2d(in_channels, out_channels // 2, kernel_size=4, stride=2, + padding=1), # 32 -> 16 + nn.LeakyReLU(inplace=True), + nn.BatchNorm2d(out_channels // 2), + nn.Conv2d(out_channels // 2, out_channels, kernel_size=4, stride=2, + padding=1), # 16 -> 8 + nn.LeakyReLU(inplace=True), + ResBlock(out_channels), + ResBlock(out_channels), + nn.Conv2d(out_channels, 4, 1), # 4x8x8 + nn.Flatten()) + + self.f_mu = nn.Linear(256, self.latent_dim) + self.f_log_sigma = nn.Linear(256, self.latent_dim) - nn.BatchNorm2d(32), - nn.ConvTranspose2d(32, 8, kernel_size=3, stride=2), - nn.LeakyReLU(inplace=True), + self.decoder_1 = nn.Sequential( + nn.Linear(self.latent_dim, 256), + nn.LeakyReLU(inplace=True), + ) - nn.BatchNorm2d(8), - nn.ConvTranspose2d(8, 1, kernel_size=3, stride=2, output_padding=1), - nn.LeakyReLU(inplace=True), - ) + self.decoder_2 = nn.Sequential( + nn.Conv2d(4, out_channels, 1), + ResBlock(out_channels), + ResBlock(out_channels), + nn.BatchNorm2d(out_channels), + nn.ConvTranspose2d(out_channels, + out_channels // 2, + kernel_size=4, + stride=2, + padding=1), + nn.LeakyReLU(inplace=True), + nn.BatchNorm2d(out_channels // 2), + nn.ConvTranspose2d(out_channels // 2, + in_channels, + kernel_size=4, + stride=2, + padding=1), + nn.LeakyReLU(inplace=True), + ) def forward(self, X): z_h = self.encoder(X) @@ -61,10 +83,11 @@ def forward(self, X): z_mu = self.f_mu(z_h) z_log_sigma = self.f_log_sigma(z_h) - z = z_mu + z_log_sigma.exp()*torch.rand_like(z_mu).to('mps') + device = next(self.f_mu.parameters()).device + z = z_mu + z_log_sigma.exp() * torch.rand_like(z_mu).to(device) x_h_1 = self.decoder_1(z) - x_h = self.decoder_2(x_h_1.view(-1, 8, 6, 6)) + x_h = self.decoder_2(x_h_1.view(-1, 4, 8, 8)) return x_h, z_mu, z_log_sigma def calculate_loss(self, x, x_h, z_mu, z_log_sigma) -> dict[str, torch.Tensor]: @@ -72,39 +95,52 @@ def calculate_loss(self, x, x_h, z_mu, z_log_sigma) -> dict[str, torch.Tensor]: # p(z) = N(0, 1) L_rec = torch.nn.MSELoss() - loss_kl = -1 * torch.mean(torch.sum(z_log_sigma + 0.5*(1 - z_log_sigma.exp()**2 - z_mu**2), dim=1), dim=0) + loss_kl = -1 * torch.mean(torch.sum( + z_log_sigma + 0.5 * (1 - z_log_sigma.exp()**2 - z_mu**2), dim=1), + dim=0) loss_rec = L_rec(x, x_h) - return {'loss': loss_rec + self.kl_weight * loss_kl, 'loss_rec': loss_rec, 'loss_kl': 
loss_kl} + return { + 'loss': loss_rec + self.kl_weight * loss_kl, + 'loss_rec': loss_rec, + 'loss_kl': loss_kl + } + def image_preprocessing(img: Image): return torchvision.transforms.ToTensor()(img) + if __name__ == "__main__": - train_mnist_data = torchvision.datasets.MNIST(str(Path()/'data'/'mnist'), - download=True, - train=True, - transform=image_preprocessing) - test_mnist_data = torchvision.datasets.MNIST(str(Path()/'data'/'mnist'), - download=True, - train=False, - transform=image_preprocessing) - train_data_loader = torch.utils.data.DataLoader(train_mnist_data, - batch_size=128, - shuffle=True, - num_workers=8) - test_data_loader = torch.utils.data.DataLoader(test_mnist_data, - batch_size=128, - shuffle=True, - num_workers=8) + import torch.multiprocessing + + # fix for "unable to open shared memory on mac" + torch.multiprocessing.set_sharing_strategy('file_system') + + train_data = torchvision.datasets.CIFAR10(str(Path() / 'data' / 'cifar10'), + download=True, + train=True, + transform=image_preprocessing) + test_data = torchvision.datasets.CIFAR10(str(Path() / 'data' / 'cifar10'), + download=True, + train=False, + transform=image_preprocessing) + train_data_loader = torch.utils.data.DataLoader(train_data, + batch_size=128, + shuffle=True, + num_workers=8) + test_data_loader = torch.utils.data.DataLoader(test_data, + batch_size=128, + shuffle=True, + num_workers=8) import socket from datetime import datetime current_time = datetime.now().strftime("%b%d_%H-%M-%S") logger = SummaryWriter(log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}") device = 'mps' - model = VAE().to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + model = VAE(latent_dim=256).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=2e-4) for epoch in tqdm(range(100)): @@ -113,7 +149,8 @@ def image_preprocessing(img: Image): for sample_num, (img, target) in enumerate(train_data_loader): recovered_img, z_mu, z_log_sigma = model(img.to(device)) - losses = model.calculate_loss(img.to(device), recovered_img, z_mu, z_log_sigma) + losses = model.calculate_loss(img.to(device), recovered_img, z_mu, + z_log_sigma) loss = losses['loss'] optimizer.zero_grad() @@ -121,12 +158,14 @@ def image_preprocessing(img: Image): optimizer.step() for loss_kind in losses: - logger.add_scalar(f'train/{loss_kind}', losses[loss_kind].cpu().detach(), epoch*len(train_data_loader)+sample_num) + logger.add_scalar(f'train/{loss_kind}', losses[loss_kind].cpu().detach(), + epoch * len(train_data_loader) + sample_num) val_losses = defaultdict(list) for img, target in test_data_loader: recovered_img, z_mu, z_log_sigma = model(img.to(device)) - losses = model.calculate_loss(img.to(device), recovered_img, z_mu, z_log_sigma) + losses = model.calculate_loss(img.to(device), recovered_img, z_mu, + z_log_sigma) for loss_kind in losses: val_losses[loss_kind].append(losses[loss_kind].cpu().detach()) @@ -134,4 +173,5 @@ def image_preprocessing(img: Image): for loss_kind in val_losses: logger.add_scalar(f'val/{loss_kind}', np.mean(val_losses[loss_kind]), epoch) logger.add_image(f'val/example_image', img.cpu().detach()[0], epoch) - logger.add_image(f'val/example_image_rec', recovered_img.cpu().detach()[0], epoch) + logger.add_image(f'val/example_image_rec', + recovered_img.cpu().detach()[0], epoch) diff --git a/rl_sandbox/vision/vq_vae.py b/rl_sandbox/vision/vq_vae.py new file mode 100644 index 0000000..9a2552c --- /dev/null +++ b/rl_sandbox/vision/vq_vae.py @@ -0,0 +1,167 @@ +from collections import defaultdict +from 
pathlib import Path + +import numpy as np +import torch +import torchvision +from PIL.Image import Image +from torch import nn +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +from rl_sandbox.vision.vae import ResBlock + + +class VQ_VAE(nn.Module): + + def __init__(self, latent_space_size, latent_dim, beta=0.25): + super().__init__() + # amount of the discrete vectors + self.latent_space_size = latent_space_size + # dimensionality of each category + self.latent_dim = latent_dim + self.beta = beta + + self.latent_space = torch.nn.Parameter( + torch.empty(size=(self.latent_space_size, self.latent_dim))) + torch.nn.init.kaiming_uniform_(self.latent_space) + + in_channels = 3 + out_channels = 128 + + self.encoder = nn.Sequential( + nn.BatchNorm2d(3), + nn.Conv2d(in_channels, out_channels // 2, kernel_size=4, stride=2, + padding=1), # 32 -> 16 + nn.LeakyReLU(inplace=True), + nn.BatchNorm2d(out_channels // 2), + nn.Conv2d(out_channels // 2, out_channels, kernel_size=4, stride=2, + padding=1), # 16 -> 8 + nn.LeakyReLU(inplace=True), + ResBlock(out_channels), + ResBlock(out_channels), + nn.Conv2d(out_channels, latent_dim, 1), # Dx8x8 + ) + + self.decoder = nn.Sequential( + nn.Conv2d(latent_dim, out_channels, 1), + ResBlock(out_channels), + ResBlock(out_channels), + nn.BatchNorm2d(out_channels), + nn.ConvTranspose2d(out_channels, + out_channels // 2, + kernel_size=4, + stride=2, + padding=1), + nn.LeakyReLU(inplace=True), + nn.BatchNorm2d(out_channels // 2), + nn.ConvTranspose2d(out_channels // 2, + in_channels, + kernel_size=4, + stride=2, + padding=1), + nn.LeakyReLU(inplace=True), + ) + + def quantize(self, z): + # z <- BxDxHxW + # Pytorch BUG: https://github.com/pytorch/pytorch/issues/84206 + # .to(memory_format=torch.contiguous_format) should be used instead of .contigious() on mac m1 + latents = torch.permute(z, (0, 2, 3, 1)).to(memory_format=torch.contiguous_format) # BxHxWxD + flatten = latents.view(-1, self.latent_dim) # BHWxD + + # use the property that (a - b)^2 = a^2 - 2ab + b^2 + l2_dist = torch.sum(flatten**2, dim=1, keepdim=True) - 2 * ( + flatten @ self.latent_space.T) + torch.sum(self.latent_space**2, dim=1) # BHWxK + + ks = torch.argmin(l2_dist, dim=1) + + flatten_quantized_latents = torch.index_select(self.latent_space, 0, ks) # BHWxD + e = flatten_quantized_latents.view(latents.shape).permute((0, 3, 1, 2)).to(memory_format=torch.contiguous_format) + z.retain_grad() + e.grad = z.grad + return e + + + def forward(self, X): + z = self.encoder(X) + e = self.quantize(z) + x_h = self.decoder(e) + return x_h, z, e + + def calculate_loss(self, x, x_h, z, e) -> dict[str, torch.Tensor]: + # loss = log p(x | z) + || stop_grad(e) - z ||_2 + beta *|| e - stop_grad(z) ||_2 + L_rec = torch.nn.MSELoss() + + loss_reg = torch.norm(e.detach() - z, + p=2) + self.beta * torch.norm(e - z.detach(), p=2) + loss_rec = L_rec(x, x_h) + + return {'loss': loss_rec + loss_reg, 'loss_rec': loss_rec, 'loss_reg': loss_reg} + + +def image_preprocessing(img: Image): + return torchvision.transforms.ToTensor()(img) + + +if __name__ == "__main__": + # fix for "unable to open shared memory on mac" + torch.multiprocessing.set_sharing_strategy('file_system') + + train_data = torchvision.datasets.CIFAR10(str(Path() / 'data' / 'cifar10'), + download=True, + train=True, + transform=image_preprocessing) + test_data = torchvision.datasets.CIFAR10(str(Path() / 'data' / 'cifar10'), + download=True, + train=False, + transform=image_preprocessing) + train_data_loader = 
torch.utils.data.DataLoader(train_data, + batch_size=128, + shuffle=True, + num_workers=8) + test_data_loader = torch.utils.data.DataLoader(test_data, + batch_size=128, + shuffle=True, + num_workers=8) + import socket + from datetime import datetime + current_time = datetime.now().strftime("%b%d_%H-%M-%S") + logger = SummaryWriter(log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}") + + device = 'mps' + model = VQ_VAE(latent_space_size=256, latent_dim=1).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=2e-4) + + for epoch in tqdm(range(100)): + + logger.add_scalar('epoch', epoch, epoch) + + for sample_num, (img, target) in enumerate(train_data_loader): + recovered_img, z, e = model(img.to(device)) + + losses = model.calculate_loss(img.to(device), recovered_img, z, e) + loss = losses['loss'] + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + for loss_kind in losses: + logger.add_scalar(f'train/{loss_kind}', losses[loss_kind].cpu().detach(), + epoch * len(train_data_loader) + sample_num) + + val_losses = defaultdict(list) + for img, target in test_data_loader: + recovered_img, z_mu, z_log_sigma = model(img.to(device)) + losses = model.calculate_loss(img.to(device), recovered_img, z_mu, + z_log_sigma) + + for loss_kind in losses: + val_losses[loss_kind].append(losses[loss_kind].cpu().detach()) + + for loss_kind in val_losses: + logger.add_scalar(f'val/{loss_kind}', np.mean(val_losses[loss_kind]), epoch) + logger.add_image(f'val/example_image', img.cpu().detach()[0], epoch) + logger.add_image(f'val/example_image_rec', + recovered_img.cpu().detach()[0], epoch) From 0fee6589ca2ebb45cf414f08eb63de91d0a21ce7 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 5 Nov 2022 00:17:37 +0000 Subject: [PATCH 006/106] Added clusterized sampling, and sampling of observations instead of states --- config/agent/{dqn_agent.yaml => dqn.yaml} | 1 + config/agent/dreamer_v2.yaml | 7 +++ config/config.yaml | 6 +- main.py | 38 +++++++------ rl_sandbox/agents/__init__.py | 2 + rl_sandbox/agents/{dqn_agent.py => dqn.py} | 3 +- rl_sandbox/utils/fc_nn.py | 18 +++--- rl_sandbox/utils/replay_buffer.py | 57 +++++++++++++------ rl_sandbox/utils/rollout_generation.py | 21 +++---- tests/test_replay_buffer.py | 65 ++++++++++++++++------ 10 files changed, 145 insertions(+), 73 deletions(-) rename config/agent/{dqn_agent.yaml => dqn.yaml} (65%) create mode 100644 config/agent/dreamer_v2.yaml create mode 100644 rl_sandbox/agents/__init__.py rename rl_sandbox/agents/{dqn_agent.py => dqn.py} (94%) diff --git a/config/agent/dqn_agent.yaml b/config/agent/dqn.yaml similarity index 65% rename from config/agent/dqn_agent.yaml rename to config/agent/dqn.yaml index b95a589..dfea883 100644 --- a/config/agent/dqn_agent.yaml +++ b/config/agent/dqn.yaml @@ -1,4 +1,5 @@ name: dqn +_target_: rl_sandbox.agents.DqnAgent hidden_layer_size: 16 num_layers: 1 discount_factor: 0.999 diff --git a/config/agent/dreamer_v2.yaml b/config/agent/dreamer_v2.yaml new file mode 100644 index 0000000..ca2a661 --- /dev/null +++ b/config/agent/dreamer_v2.yaml @@ -0,0 +1,7 @@ +_target_: rl_sandbox.agents.DreamerV2 +discount_factor: 0.995 +batch_cluster_size: 8 +latent_dim: 32 +latent_classes: 32 +rssm_dim: 600 +kl_loss_scale: 0.1 diff --git a/config/config.yaml b/config/config.yaml index 76962a7..d748bab 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -3,11 +3,13 @@ defaults: - _self_ env: + # type: gym + # name: CartPole-v1 type: dm_control domain_name: cartpole task_name: swingup - # type: gym - # name: CartPole-v1 + 
run_on_pixels: true + obs_res: [128, 128] seed: 42 device_type: mps diff --git a/main.py b/main.py index 29c4b38..8211034 100644 --- a/main.py +++ b/main.py @@ -6,20 +6,18 @@ from torch.utils.tensorboard.writer import SummaryWriter from tqdm import tqdm -from rl_sandbox.agents.dqn_agent import DqnAgent from rl_sandbox.agents.random_agent import RandomAgent from rl_sandbox.metrics import MetricsEvaluator from rl_sandbox.utils.dm_control import ActionDiscritizer, decode_dm_ts from rl_sandbox.utils.replay_buffer import ReplayBuffer -from rl_sandbox.utils.rollout_generation import (collect_rollout, - collect_rollout_num, +from rl_sandbox.utils.rollout_generation import (collect_rollout_num, fillup_replay_buffer) from rl_sandbox.utils.schedulers import LinearScheduler @hydra.main(version_base="1.2", config_path='config', config_name='config') def main(cfg: DictConfig): - print(OmegaConf.to_yaml(cfg)) + # print(OmegaConf.to_yaml(cfg)) match cfg.env.type: case "dm_control": @@ -29,30 +27,31 @@ def main(cfg: DictConfig): case "gym": env = gym.make(cfg.env) visualized_env = gym.make(cfg.env, render_mode='rgb_array_list') + if cfg.env.run_on_pixels: + raise NotImplementedError("Run on pixels supported only for 'dm_control'") case _: raise RuntimeError("Invalid environment type") buff = ReplayBuffer() - fillup_replay_buffer(env, buff, cfg.training.batch_size) + obs_res = cfg.env.obs_res if cfg.env.run_on_pixels else None + fillup_replay_buffer(env, buff, cfg.training.batch_size, obs_res=obs_res) - agent_params = {**cfg.agent} - agent_name = agent_params.pop('name') action_disritizer = ActionDiscritizer(env.action_spec(), values_per_dim=10) metrics_evaluator = MetricsEvaluator() match cfg.env.type: case "dm_control": obs_space_num = sum([v.shape[0] for v in env.observation_spec().values()]) + if cfg.env.run_on_pixels: + obs_space_num = (*cfg.env.obs_res, 3) case "gym": obs_space_num = env.observation_space.shape[0] exploration_agent = RandomAgent(env) - agent = DqnAgent(obs_space_num=obs_space_num, - actions_num=action_disritizer.shape, - # actions_num=env.action_space.n, - device_type=cfg.device_type, - **agent_params, - ) + agent = hydra.utils.instantiate(cfg.agent, + obs_space_num=obs_space_num, + actions_num=action_disritizer.shape, + device_type=cfg.device_type) writer = SummaryWriter() @@ -65,6 +64,7 @@ def main(cfg: DictConfig): match cfg.env.type: case "dm_control": state, _, _ = decode_dm_ts(env.reset()) + obs = env.physics.render(*cfg.env.obs_res, camera_id=0) if cfg.env.run_on_pixels else None case "gym": state, _ = env.reset() @@ -79,14 +79,18 @@ def main(cfg: DictConfig): match cfg.env.type: case "dm_control": new_state, reward, terminated = decode_dm_ts(env.step(action_disritizer.undiscretize(action))) + # FIXME: if run_on_pixels next_state should also be observation + obs = env.physics.render(*cfg.env.obs_res, camera_id=0) if cfg.env.run_on_pixels else None case "gym": new_state, reward, terminated, _, _ = env.step(action) action = action_disritizer.undiscretize(action) + obs = None - buff.add_sample(state, action, reward, new_state, terminated) + buff.add_sample(state, action, reward, new_state, terminated, obs) - s, a, r, n, f = buff.sample(cfg.training.batch_size) - a = np.stack([action_disritizer.discretize(a_) for a_ in a]).reshape(-1, 1) + s, a, r, n, f = buff.sample(cfg.training.batch_size, + return_observation=cfg.env.run_on_pixels, + cluster_size=cfg.agent.get('batch_cluster_size', 1)) loss = agent.train(s, a, r, n, f) writer.add_scalar('train/loss', loss, global_step) @@ 
-100,7 +104,7 @@ def main(cfg: DictConfig): writer.add_scalar(f'val/{metric_name}', metric, epoch_num) if cfg.validation.visualize: - rollouts = collect_rollout_num(visualized_env, 1, agent, save_obs=True) + rollouts = collect_rollout_num(visualized_env, 1, agent, obs_res=cfg.obs_res) for rollout in rollouts: video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) diff --git a/rl_sandbox/agents/__init__.py b/rl_sandbox/agents/__init__.py new file mode 100644 index 0000000..d1b779f --- /dev/null +++ b/rl_sandbox/agents/__init__.py @@ -0,0 +1,2 @@ +from rl_sandbox.agents.dqn import DqnAgent +from rl_sandbox.agents.dreamer_v2 import DreamerV2 diff --git a/rl_sandbox/agents/dqn_agent.py b/rl_sandbox/agents/dqn.py similarity index 94% rename from rl_sandbox/agents/dqn_agent.py rename to rl_sandbox/agents/dqn.py index 5ca9368..9c7b01f 100644 --- a/rl_sandbox/agents/dqn_agent.py +++ b/rl_sandbox/agents/dqn.py @@ -18,7 +18,8 @@ def __init__(self, actions_num: int, self.value_func = fc_nn_generator(obs_space_num, actions_num, hidden_layer_size, - num_layers).to(device_type) + num_layers, + torch.nn.ReLU).to(device_type) self.optimizer = torch.optim.Adam(self.value_func.parameters(), lr=1e-3) self.loss = torch.nn.MSELoss() self.device_type = device_type diff --git a/rl_sandbox/utils/fc_nn.py b/rl_sandbox/utils/fc_nn.py index 85e2f09..a66d082 100644 --- a/rl_sandbox/utils/fc_nn.py +++ b/rl_sandbox/utils/fc_nn.py @@ -1,15 +1,17 @@ +import typing as t from torch import nn -def fc_nn_generator(obs_space_num: int, - action_space_num: int, - hidden_layer_size: int, - num_layers: int): +def fc_nn_generator(input_num: int, + output_num: int, + hidden_size: int, + num_layers: int, + final_activation: t.Type[nn.Module] = nn.Identity): layers = [] - layers.append(nn.Linear(obs_space_num, hidden_layer_size)) + layers.append(nn.Linear(input_num, hidden_size)) layers.append(nn.ReLU(inplace=True)) for _ in range(num_layers): - layers.append(nn.Linear(hidden_layer_size, hidden_layer_size)) + layers.append(nn.Linear(hidden_size, hidden_size)) layers.append(nn.ReLU(inplace=True)) - layers.append(nn.Linear(hidden_layer_size, action_space_num)) - layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Linear(hidden_size, output_num)) + layers.append(final_activation()) return nn.Sequential(*layers) diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index 3835b0b..1d56008 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -5,15 +5,16 @@ import numpy as np from nptyping import Bool, Float, Int, NDArray, Shape -Observation = NDArray[Shape["*,*,3"],Int] -State = NDArray[Shape["*"],Float] -Action = NDArray[Shape["*"],Int] +Observation = NDArray[Shape["*,*,3"], Int] +State = NDArray[Shape["*"], Float] +Action = NDArray[Shape["*"], Int] + +Observations = NDArray[Shape["*,*,*,3"], Int] +States = NDArray[Shape["*,*"], Float] +Actions = NDArray[Shape["*,*"], Int] +Rewards = NDArray[Shape["*"], Float] +TerminationFlags = NDArray[Shape["*"], Bool] -Observations = NDArray[Shape["*,*,*,3"],Int] -States = NDArray[Shape["*,*"],Float] -Actions = NDArray[Shape["*,*"],Int] -Rewards = NDArray[Shape["*"],Float] -TerminationFlags = NDArray[Shape["*"],Bool] @dataclass class Rollout: @@ -26,12 +27,14 @@ class Rollout: class ReplayBuffer: + def __init__(self, max_len=2_000): self.max_len = max_len self.states: States = np.array([]) self.actions: Actions = np.array([]) self.rewards: Rewards = np.array([]) self.next_states: States = np.array([]) + 
self.observations: t.Optional[Observations] def add_rollout(self, rollout: Rollout): if len(self.states) == 0: @@ -40,27 +43,45 @@ def add_rollout(self, rollout: Rollout): self.rewards = rollout.rewards self.next_states = rollout.next_states self.is_finished = rollout.is_finished + self.observations = rollout.observations else: self.states = np.concatenate([self.states, rollout.states]) self.actions = np.concatenate([self.actions, rollout.actions]) self.rewards = np.concatenate([self.rewards, rollout.rewards]) self.next_states = np.concatenate([self.next_states, rollout.next_states]) self.is_finished = np.concatenate([self.is_finished, rollout.is_finished]) + if self.observations is not None: + self.observations = np.concatenate( + [self.observations, rollout.observations]) if len(self.states) >= self.max_len: - self.states = self.states - self.actions = self.actions - self.rewards = self.rewards - self.next_states = self.next_states - self.is_finished = self.is_finished + self.states = self.states[:self.max_len] + self.actions = self.actions[:self.max_len] + self.rewards = self.rewards[:self.max_len] + self.next_states = self.next_states[:self.max_len] + self.is_finished = self.is_finished[:self.max_len] + if self.observations is not None: + self.observations = self.observations[:self.max_len] - def add_sample(self, s: State, a: Action, r: float, n: State, f: bool): - rollout = Rollout(np.array([s]), np.expand_dims(np.array([a]), 0), np.array([r], dtype=np.float32), np.array([n]), np.array([f])) + def add_sample(self, s: State, a: Action, r: float, n: State, f: bool, + o: t.Optional[Observation]): + rollout = Rollout(np.array([s]), np.expand_dims(np.array([a]), 0), + np.array([r], dtype=np.float32), np.array([n]), np.array([f]), + np.array([o]) if o is not None else None) self.add_rollout(rollout) def can_sample(self, num: int): return len(self.states) >= num - def sample(self, num: int) -> t.Tuple[States, Actions, Rewards, States, TerminationFlags]: - indeces = np.random.choice(len(self.states), num) - return self.states[indeces], self.actions[indeces], self.rewards[indeces], self.next_states[indeces], self.is_finished[indeces] + def sample( + self, + batch_size: int, + return_observation: bool = False, + cluster_size: int = 1 + ) -> t.Tuple[States, Actions, Rewards, States, TerminationFlags]: + # TODO: add warning if batch_size % cluster_size != 0 + indeces = np.random.choice(len(self.states) - (cluster_size - 1), batch_size//cluster_size) + indeces = np.stack([indeces + i for i in range(cluster_size)]).flatten(order='F') + o = self.states[indeces] if not return_observation else self.observations[indeces] + return o, self.actions[indeces], self.rewards[indeces], self.next_states[ + indeces], self.is_finished[indeces] diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index 8f19064..2aed382 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -2,14 +2,14 @@ import gym import numpy as np -from dm_control.composer.environment import Environment as dmEnv +from dm_env import Environment as dmEnv from rl_sandbox.agents.random_agent import RandomAgent from rl_sandbox.utils.dm_control import ActionDiscritizer, decode_dm_ts from rl_sandbox.utils.replay_buffer import ReplayBuffer, Rollout -def collect_rollout(env: gym.Env | dmEnv, agent: t.Optional[t.Any] = None, save_obs: bool = False) -> Rollout: +def collect_rollout(env: gym.Env | dmEnv, agent: t.Optional[t.Any] = None, obs_res: t.Optional[t.Tuple[int, 
int]] = None) -> Rollout: s, a, r, n, f, o = [], [], [], [], [], [] match env: @@ -21,6 +21,7 @@ def collect_rollout(env: gym.Env | dmEnv, agent: t.Optional[t.Any] = None, save_ if agent is None: agent = RandomAgent(env) + while not terminated: action = agent.get_action(state) @@ -38,26 +39,26 @@ def collect_rollout(env: gym.Env | dmEnv, agent: t.Optional[t.Any] = None, save_ n.append(new_state) f.append(terminated) - if save_obs and isinstance(env, dmEnv): - o.append(env.physics.render(128, 128, camera_id=0)) + if obs_res is not None and isinstance(env, dmEnv): + o.append(env.physics.render(*obs_res, camera_id=0)) state = new_state match env: case gym.Env(): - obs = np.stack(list(env.render())) if save_obs else None + obs = np.stack(list(env.render())) if obs_res is not None else None case dmEnv(): - obs = np.array(o) if save_obs else None + obs = np.array(o) if obs_res is not None else None return Rollout(np.array(s), np.array(a).reshape(len(s), -1), np.array(r, dtype=np.float32), np.array(n), np.array(f), obs) -def collect_rollout_num(env: gym.Env, num: int, agent: t.Optional[t.Any] = None, save_obs: bool = False) -> t.List[Rollout]: +def collect_rollout_num(env: gym.Env, num: int, agent: t.Optional[t.Any] = None, obs_res: bool = False) -> t.List[Rollout]: # TODO: paralelyze rollouts = [] for _ in range(num): - rollouts.append(collect_rollout(env, agent, save_obs)) + rollouts.append(collect_rollout(env, agent, obs_res)) return rollouts -def fillup_replay_buffer(env: gym.Env, rep_buffer: ReplayBuffer, num: int): +def fillup_replay_buffer(env: gym.Env, rep_buffer: ReplayBuffer, num: int, obs_res: t.Optional[t.Tuple[int, int]] = None): # TODO: paralelyze while not rep_buffer.can_sample(num): - rep_buffer.add_rollout(collect_rollout(env)) + rep_buffer.add_rollout(collect_rollout(env, obs_res=obs_res)) diff --git a/tests/test_replay_buffer.py b/tests/test_replay_buffer.py index beea74a..a08ab56 100644 --- a/tests/test_replay_buffer.py +++ b/tests/test_replay_buffer.py @@ -1,54 +1,85 @@ -import numpy as np import random + +import numpy as np from pytest import fixture -from rl_sandbox.utils.replay_buffer import ReplayBuffer +from rl_sandbox.utils.replay_buffer import ReplayBuffer, Rollout + @fixture def rep_buf(): return ReplayBuffer() -def test_creation(rep_buf): + +def test_creation(rep_buf: ReplayBuffer): assert len(rep_buf.states) == 0 -def test_adding(rep_buf): + +def test_adding(rep_buf: ReplayBuffer): s = np.ones((3, 8)) - a = np.ones((3, 3)) + a = np.ones((3, 3), dtype=np.int32) r = np.ones((3)) - rep_buf.add_rollout(s, a, r) + n = np.ones((3, 8)) + f = np.zeros((3), dtype=np.bool8) + rep_buf.add_rollout(Rollout(s, a, r, n, f)) assert len(rep_buf.states) == 3 assert len(rep_buf.actions) == 3 assert len(rep_buf.rewards) == 3 s = np.zeros((3, 8)) - a = np.zeros((3, 3)) + a = np.zeros((3, 3), dtype=np.int32) r = np.zeros((3)) - rep_buf.add_rollout(s, a, r) + n = np.zeros((3, 8)) + f = np.zeros((3), dtype=np.bool8) + rep_buf.add_rollout(Rollout(s, a, r, n, f)) assert len(rep_buf.states) == 6 assert len(rep_buf.actions) == 6 assert len(rep_buf.rewards) == 6 -def test_can_sample(rep_buf): + +def test_can_sample(rep_buf: ReplayBuffer): assert rep_buf.can_sample(1) == False s = np.ones((3, 8)) - a = np.ones((3, 3)) + a = np.zeros((3, 3), dtype=np.int32) r = np.ones((3)) - rep_buf.add_rollout(s, a, r) + n = np.zeros((3, 8)) + f = np.zeros((3), dtype=np.bool8) + rep_buf.add_rollout(Rollout(s, a, r, n, f)) assert rep_buf.can_sample(5) == False assert rep_buf.can_sample(1) == True - 
rep_buf.add_rollout(s, a, r) + rep_buf.add_rollout(Rollout(s, a, r, n, f)) assert rep_buf.can_sample(5) == True -def test_sampling(rep_buf): - for i in range(1, 5): - rep_buf.add_rollout(np.ones((1,3)), np.ones((1,2)), i*np.ones((1))) + +def test_sampling(rep_buf: ReplayBuffer): + for i in range(5): + rep_buf.add_rollout( + Rollout(np.ones((1, 3)), np.ones((1, 2), dtype=np.int32), i * np.ones((1)), + np.ones((3, 8)), np.zeros((3), dtype=np.bool8))) random.seed(42) - _, _, r = rep_buf.sample(3) - assert (r == [3, 2, 4]).all() + _, _, r, _, _ = rep_buf.sample(3) + assert (r == [1, 0, 3]).all() + + +def test_cluster_sampling(rep_buf: ReplayBuffer): + for i in range(5): + rep_buf.add_rollout( + Rollout(np.ones((1, 3)), np.ones((1, 2), dtype=np.int32), i * np.ones((1)), + np.ones((3, 8)), np.zeros((3), dtype=np.bool8))) + + random.seed(42) + _, _, r, _, _ = rep_buf.sample(4, cluster_size=2) + assert (r == [1, 2, 3, 4]).all() + + _, _, r, _, _ = rep_buf.sample(4, cluster_size=2) + assert (r == [0, 1, 1, 2]).all() + + _, _, r, _, _ = rep_buf.sample(4, cluster_size=2) + assert (r == [2, 3, 2, 3]).all() From 7907e33b10a682d1eda1712412e243bb341e67f9 Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 8 Nov 2022 00:18:25 +0000 Subject: [PATCH 007/106] Written skeleton for Dreamer and important future notes --- rl_sandbox/agents/dreamer_v2.py | 219 ++++++++++++++++++++++++++++++ rl_sandbox/utils/replay_buffer.py | 1 + rl_sandbox/vision/vq_vae.py | 4 +- 3 files changed, 221 insertions(+), 3 deletions(-) create mode 100644 rl_sandbox/agents/dreamer_v2.py diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py new file mode 100644 index 0000000..3c6d197 --- /dev/null +++ b/rl_sandbox/agents/dreamer_v2.py @@ -0,0 +1,219 @@ +from collections import defaultdict + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from rl_sandbox.agents.rl_agent import RlAgent +from rl_sandbox.utils.fc_nn import fc_nn_generator +from rl_sandbox.utils.replay_buffer import (Action, Actions, Observations, + Rewards, State, States, + TerminationFlags) +from rl_sandbox.vision.vq_vae import VQ_VAE + + +class RSSM(nn.Module): + """ + Recurrent State Space Model + """ + + def __init__(self, latent_dim, hidden_size, actions_num): + super().__init__() + + self.gru = nn.GRU(input_size=latent_dim + actions_num, hidden_size=hidden_size) + + def forward(self, h_prev, z, a): + """ + 'h' <- internal state of the world + 'z' <- latent embedding of current observation + 'a' <- action taken on current step + Returns 'h_next' <- the next next of the world + """ + + _, h_n = self.gru(torch.concat([z, a]), h_prev) + # NOTE: except deterministic step h_t, model should also return stochastic state concatenated + # NOTE: to add stoshasticity for internal state, ensemble of MLP's is used + return h_n + + +# NOTE: In Dreamer ELU is used everywhere as activation func +# NOTE: In Dreamer 48**(lvl) filter size is used, 4 level of convolution, +# Layer Normalizatin instead of Batch +# NOTE: residual blocks are not used inside dreamer +class Encoder(nn.Module): + + def __init__(self, kernel_sizes=[4, 4, 4, 4]): + super().__init__() + layers = [] + + channel_step = 48 + in_channels = 3 + for i, k in enumerate(kernel_sizes): + out_channels = 2**i * channel_step + layers.append( + nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2, padding=1)) + layers.append(nn.ELU(inplace=True)) + layers.append(nn.LayerNorm(out_channels)) + in_channels = out_channels + layers.append(nn.Flatten()) + 
self.net = nn.Sequential(*layers) + + def forward(self, X): + return self.net(X) + +class Decoder(nn.Module): + + def __init__(self, kernel_sizes=[4, 4, 4, 4]): + super().__init__() + layers = [] + + channel_step = 48 + in_channels = 2**(len(kernel_sizes)-1) *channel_step + for i, k in enumerate(kernel_sizes): + out_channels = 2**(len(kernel_sizes) - i - 2) * channel_step + if out_channels == channel_step: + out_channels = 3 + layers.append(nn.ConvTranspose2d(in_channels, out_channels, + kernel_size=k, + stride=2, + padding=1 + )) + layers.append(nn.ELU(inplace=True)) + layers.append(nn.LayerNorm(out_channels)) + in_channels = out_channels + self.net = nn.Sequential(*layers) + + def forward(self, X): + return self.net(X) + + +class WorldModel(nn.Module): + + def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, + actions_num, kl_loss_scale): + self.kl_beta = kl_loss_scale + self.rssm_dim = rssm_dim + self.cluster_size = batch_cluster_size + self.recurrent_model = RSSM(latent_dim, rssm_dim, actions_num) + # NOTE: In Dreamer paper VQ-VAE has MLP after conv2d to get 1d embedding, + # which is concatenated with deterministic state and only after that + # sampled into discrete one-hot encoding (using TensorFlow.Distribution OneHotCategorical) + # self.representation_network = VQ_VAE( + # latent_dim=latent_dim, + # latent_space_size=latent_classes) # actually only 'encoder' part of VAE + self.encoder = Encoder() + self.image_predictor = Decoder() + # self.image_predictor = 'decoder' part of VAE + # FIXME: will not work until VQ-VAE internal embedding will not be changed from 2d to 1d + # FIXME: in Dreamer paper it is 4 hidden layers with 400 hidden units + # FIXME: in Dramer paper it has Layer Normalization after Dense + self.transition_network = fc_nn_generator(rssm_dim, + latent_dim, + hidden_size=128, + num_layers=3) + self.reward_predictor = fc_nn_generator(rssm_dim + latent_dim, + 1, + hidden_size=128, + num_layers=3) + self.discount_predictor = fc_nn_generator(rssm_dim + latent_dim, + 1, + hidden_size=128, + num_layers=3) + + self.optimizer = torch.optim.Adam(self.representation_network.parameters(), + lr=2e-4) + + def forward(self, X): + pass + + def train(self, s: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + is_finished: torch.Tensor): + b, h, w, _ = s.shape # s <- BxHxWx3 + s = s.view(-1, self.cluster_size, h, w, 3) + a = a.view(-1, self.cluster_size, a.shape[1]) + r = r.view(-1, self.cluster_size, 1) + f = is_finished.view(-1, self.cluster_size, 1) + + h_prev = torch.zeros((b, self.rssm_dim)) + losses = defaultdict(lambda: torch.zeros(1)) + + embed = self.encoder(s) + for t in range(self.cluster_size): + # s_t <- 1xB^xHxWx3 + embed_t, a_t, r_t, f_t = embed[:, t].unsqueeze(0), a[:, t].unsqueeze( + 0), r[:, t].unsqueeze(0), f[:, t].unsqueeze(0) + + # TODO: add in the future h_t into representation network + # NOTE: can be moved out of the loop, *embed* is calculated solely by image + # s_t_r, z_t, e_t = self.representation_network(s_t) + h_t = self.recurrent_model(h_prev, z_t, a_t) + + r_t_pred = self.reward_predictor(torch.concat([h_t, z_t])) + f_t_pred = self.discount_predictor(torch.concat([h_t, z_t])) + z_t_prior = self.transition_network(h_t) + + vae_losses = self.representation_network.calculate_loss(s_t, s_t_r, z_t, e_t) + # NOTE: regularization loss from VQ-VAE is not used in Dreamer paper + losses['loss_reconstruction'] = vae_losses['loss_rec'] + losses['loss_reward_pred'] += F.mse_loss(r_t, r_t_pred) + losses['loss_discount_pred'] += 
F.cross_entropy(f_t, f_t_pred) + # TODO: add KL divergence loss between transition predictor and representation model + # NOTE: remember about different learning rate for prior and posterior + # NOTE: VQ-VAE should be changed to output the softmax of how close z is to each e, + # so it can be used to count as probability for each distribution to calculate + # the KL divergence + # NOTE: DreamerV2 uses TensorFlow.Probability to calculate KL divergence + losses['loss_kl_reg'] += self.kl_beta * 0 + + h_prev = h_t + + loss = torch.Tensor(0) + for l in losses.values(): + loss += l + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + return {l: val.detach() for l, val in losses.items()} + + +class DreamerV2(RlAgent): + + def __init__(self, + obs_space_num: int, + actions_num: int, + batch_cluster_size: int, + latent_dim: int, + latent_classes: int, + rssm_dim: int, + discount_factor: float, + kl_loss_scale: float, + device_type: str = 'cpu'): + + self.cluster_size = batch_cluster_size + self.actions_num = actions_num + self.gamma = discount_factor + + self.world_model = WorldModel(batch_cluster_size, latent_dim, latent_classes, + rssm_dim, actions_num, + kl_loss_scale).to(device_type) + + def get_action(self, obs: State) -> Action: + return self.actions_num + + def from_np(self, arr: np.ndarray): + return torch.from_numpy(arr).to(self.device_type) + + def train(self, s: Observations, a: Actions, r: Rewards, next: States, + is_finished: TerminationFlags): + # NOTE: next is currently incorrect (state instead of img), but also unused + + s = self.from_np(s) + a = self.from_np(a) + r = self.from_np(r) + next = self.from_np(next) + is_finished = self.from_np(is_finished) + + self.world_model.train(s, a, r, is_finished) diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index 1d56008..26a516e 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -80,6 +80,7 @@ def sample( cluster_size: int = 1 ) -> t.Tuple[States, Actions, Rewards, States, TerminationFlags]: # TODO: add warning if batch_size % cluster_size != 0 + # FIXME: currently doesn't take into account discontinuations between between rollouts indeces = np.random.choice(len(self.states) - (cluster_size - 1), batch_size//cluster_size) indeces = np.stack([indeces + i for i in range(cluster_size)]).flatten(order='F') o = self.states[indeces] if not return_observation else self.observations[indeces] diff --git a/rl_sandbox/vision/vq_vae.py b/rl_sandbox/vision/vq_vae.py index 9a2552c..c1efdc4 100644 --- a/rl_sandbox/vision/vq_vae.py +++ b/rl_sandbox/vision/vq_vae.py @@ -78,9 +78,7 @@ def quantize(self, z): flatten_quantized_latents = torch.index_select(self.latent_space, 0, ks) # BHWxD e = flatten_quantized_latents.view(latents.shape).permute((0, 3, 1, 2)).to(memory_format=torch.contiguous_format) - z.retain_grad() - e.grad = z.grad - return e + return e + (z - z.detach()) def forward(self, X): From a82032873f7a8ebc60125b09e17794c65c1d2892 Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 8 Nov 2022 17:44:09 +0000 Subject: [PATCH 008/106] Implemented World Model mainloop and loss calculation --- config/agent/dreamer_v2.yaml | 2 +- config/config.yaml | 6 +- main.py | 16 +- pyproject.toml | 4 + rl_sandbox/agents/dreamer_v2.py | 268 +++++++++++++++++++++++--------- 5 files changed, 216 insertions(+), 80 deletions(-) diff --git a/config/agent/dreamer_v2.yaml b/config/agent/dreamer_v2.yaml index ca2a661..8c6086f 100644 --- a/config/agent/dreamer_v2.yaml +++ 
b/config/agent/dreamer_v2.yaml @@ -3,5 +3,5 @@ discount_factor: 0.995 batch_cluster_size: 8 latent_dim: 32 latent_classes: 32 -rssm_dim: 600 +rssm_dim: 200 kl_loss_scale: 0.1 diff --git a/config/config.yaml b/config/config.yaml index d748bab..fb38c36 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,5 +1,5 @@ defaults: - - agent/dqn_agent + - agent/dreamer_v2 - _self_ env: @@ -9,10 +9,10 @@ env: domain_name: cartpole task_name: swingup run_on_pixels: true - obs_res: [128, 128] + obs_res: [64, 64] seed: 42 -device_type: mps +device_type: cpu training: epochs: 5000 diff --git a/main.py b/main.py index 8211034..48562d2 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ import numpy as np from dm_control import suite from omegaconf import DictConfig, OmegaConf +import numpy as np from torch.utils.tensorboard.writer import SummaryWriter from tqdm import tqdm @@ -48,9 +49,10 @@ def main(cfg: DictConfig): obs_space_num = env.observation_space.shape[0] exploration_agent = RandomAgent(env) + # FIXME: currently action is 1 value, but not one-hot encoding agent = hydra.utils.instantiate(cfg.agent, obs_space_num=obs_space_num, - actions_num=action_disritizer.shape, + actions_num=(1), device_type=cfg.device_type) writer = SummaryWriter() @@ -88,12 +90,20 @@ def main(cfg: DictConfig): buff.add_sample(state, action, reward, new_state, terminated, obs) + # FIXME: unintuitive that batch_size is now number of total + # samples, but not amount of sequences for recurrent model s, a, r, n, f = buff.sample(cfg.training.batch_size, return_observation=cfg.env.run_on_pixels, cluster_size=cfg.agent.get('batch_cluster_size', 1)) - loss = agent.train(s, a, r, n, f) - writer.add_scalar('train/loss', loss, global_step) + losses = agent.train(s, a, r, n, f) + if isinstance(losses, np.ndarray): + writer.add_scalar('train/loss', loss, global_step) + elif isinstance(losses, dict): + for loss_name, loss in losses.items(): + writer.add_scalar(f'train/{loss_name}', loss, global_step) + else: + raise RuntimeError("AAAA, very bad") global_step += 1 ### Validation diff --git a/pyproject.toml b/pyproject.toml index 8e6f95e..4484501 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,10 @@ description = 'Sandbox for my RL experiments' authors = ['Roman Milishchuk '] packages = [{include = 'rl_sandbox'}] +[tool.yapf] +based_on_style = "pep8" +column_limit = 90 + [tool.poetry.dependencies] python = "^3.10" numpy = '*' diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 3c6d197..bf6d6fe 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -1,3 +1,4 @@ +import itertools from collections import defaultdict import numpy as np @@ -10,31 +11,136 @@ from rl_sandbox.utils.replay_buffer import (Action, Actions, Observations, Rewards, State, States, TerminationFlags) -from rl_sandbox.vision.vq_vae import VQ_VAE + + +class View(nn.Module): + + def __init__(self, shape): + super().__init__() + self.shape = shape + + def forward(self, x): + return x.view(*self.shape) + + +class DebugShapeLayer(nn.Module): + + def __init__(self, note=""): + super().__init__() + self.note = note + + def forward(self, x): + if len(self.note): + print(self.note, x.shape) + else: + print(x.shape) + return x + + +class Quantize(nn.Module): + + def forward(self, logits): + dist = torch.distributions.one_hot_categorical.OneHotCategoricalStraightThrough( + logits=logits) + return dist.rsample() class RSSM(nn.Module): """ Recurrent State Space Model - """ - def __init__(self, 
latent_dim, hidden_size, actions_num): - super().__init__() + h_t <- deterministic state which is updated inside GRU + s^_t <- stohastic discrete prior state (used for KL divergence: + better predict future and encode smarter) + s_t <- stohastic discrete posterior state (latent representation of current state) + + h_1 ---> h_2 ---> h_3 ---> + \ x_1 \ x_2 \ x_3 + | \ | ^ | \ | ^ | \ | ^ + v MLP CNN | v MLP CNN | v MLP CNN | + \ | | \ | | \ | | + Ensemble \ | | Ensemble \ | | Ensemble \ | | + \| | \| | \| | + | | | | | | | | | + v v | v v | v v | + | | | + s^_1 s_1 ---| s^_2 s_2 ---| s^_3 s_3 ---| - self.gru = nn.GRU(input_size=latent_dim + actions_num, hidden_size=hidden_size) + """ - def forward(self, h_prev, z, a): + def __init__(self, latent_dim, hidden_size, actions_num, categories_num): + super().__init__() + self.latent_dim = latent_dim + self.categories_num = categories_num + self.ensemble_num = 5 + + # Calculate deterministic state from prev stochastic, prev action and prev deterministic + self.pre_determ_recurrent = nn.Sequential( + nn.Linear(latent_dim * categories_num + actions_num, hidden_size), # Dreamer 'img_in' + nn.LayerNorm(hidden_size), + ) + self.determ_recurrent = nn.GRU(input_size=hidden_size, + hidden_size=hidden_size) # Dreamer gru '_cell' + + # Calculate stochastic state from prior embed + # shared between all ensemble models + self.ensemble_prior_estimator = nn.ModuleList([ + nn.Sequential( + nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' + nn.LayerNorm(hidden_size), + nn.Linear(hidden_size, + latent_dim * self.categories_num), # Dreamer 'img_dist_{k}' + View((-1, latent_dim, self.categories_num))) + for _ in range(self.ensemble_num) + ]) + + # For observation we do not have ensemble + # FIXME: very band magic number + img_sz = 4 * 384 # 384*2x2 + self.stoch_net = nn.Sequential( + nn.Linear(hidden_size + img_sz, hidden_size), + nn.LayerNorm(hidden_size), + nn.Linear(hidden_size, hidden_size), # Dreamer 'obs_out' + nn.LayerNorm(hidden_size), + nn.Linear(hidden_size, + latent_dim * self.categories_num), # Dreamer 'obs_dist' + View((-1, latent_dim, self.categories_num)), + # NOTE: Maybe worth having some LogSoftMax as activation + # before using input as logits for distribution + # Quantize() + ) + + def estimate_stochastic_latent(self, prev_determ): + logits_per_model = torch.stack( + [model(prev_determ) for model in self.ensemble_prior_estimator]) + # NOTE: Maybe something smarter can be used instead of + # taking only one random between all ensembles + index = torch.randint(0, self.ensemble_num, ()) + return torch.distributions.one_hot_categorical.OneHotCategoricalStraightThrough( + logits=logits_per_model[index]) + + def forward(self, h_prev: tuple[torch.Tensor, torch.Tensor], embed, action): """ 'h' <- internal state of the world 'z' <- latent embedding of current observation - 'a' <- action taken on current step + 'a' <- action taken on prev step Returns 'h_next' <- the next next of the world """ - _, h_n = self.gru(torch.concat([z, a]), h_prev) - # NOTE: except deterministic step h_t, model should also return stochastic state concatenated - # NOTE: to add stoshasticity for internal state, ensemble of MLP's is used - return h_n + # Use zero vector for prev_state of first + deter_prev, stoch_prev = h_prev + x = self.pre_determ_recurrent(torch.concat([stoch_prev, action], dim=2)) + _, determ = self.determ_recurrent(x, deter_prev) + + # used for KL divergence + prior_stoch_dist = self.estimate_stochastic_latent(determ) + + posterior_stoch_logits = 
self.stoch_net(torch.concat([determ, embed], + dim=2)) # Dreamer 'obs_out' + posterior_stoch_dist = torch.distributions.one_hot_categorical.OneHotCategoricalStraightThrough( + logits=posterior_stoch_logits) + + return [determ, prior_stoch_dist, posterior_stoch_dist] # NOTE: In Dreamer ELU is used everywhere as activation func @@ -51,10 +157,10 @@ def __init__(self, kernel_sizes=[4, 4, 4, 4]): in_channels = 3 for i, k in enumerate(kernel_sizes): out_channels = 2**i * channel_step - layers.append( - nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2, padding=1)) + layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2)) layers.append(nn.ELU(inplace=True)) - layers.append(nn.LayerNorm(out_channels)) + # FIXME: change to layer norm when sizes will be known + layers.append(nn.BatchNorm2d(out_channels)) in_channels = out_channels layers.append(nn.Flatten()) self.net = nn.Sequential(*layers) @@ -62,66 +168,75 @@ def __init__(self, kernel_sizes=[4, 4, 4, 4]): def forward(self, X): return self.net(X) + class Decoder(nn.Module): - def __init__(self, kernel_sizes=[4, 4, 4, 4]): + def __init__(self, kernel_sizes=[5, 5, 6, 6]): super().__init__() layers = [] + self.channel_step = 48 + # 2**(len(kernel_sizes)-1)*channel_step + self.convin = nn.Linear(32*32, 32*self.channel_step) - channel_step = 48 - in_channels = 2**(len(kernel_sizes)-1) *channel_step + in_channels = 32*self.channel_step #2**(len(kernel_sizes) - 1) * self.channel_step for i, k in enumerate(kernel_sizes): - out_channels = 2**(len(kernel_sizes) - i - 2) * channel_step - if out_channels == channel_step: + out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step + if i == len(kernel_sizes) - 1: out_channels = 3 - layers.append(nn.ConvTranspose2d(in_channels, out_channels, - kernel_size=k, - stride=2, - padding=1 - )) - layers.append(nn.ELU(inplace=True)) - layers.append(nn.LayerNorm(out_channels)) + layers.append( + nn.ConvTranspose2d(in_channels, 3, kernel_size=k, stride=2)) + else: + layers.append( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=2)) + layers.append(nn.ELU(inplace=True)) + layers.append(nn.BatchNorm2d(out_channels)) in_channels = out_channels self.net = nn.Sequential(*layers) def forward(self, X): - return self.net(X) + x = self.convin(X) + x = x.view(-1, 32*self.channel_step, 1, 1) + return self.net(x) class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale): + super().__init__() self.kl_beta = kl_loss_scale self.rssm_dim = rssm_dim + self.latent_dim = latent_dim + self.latent_classes = latent_classes self.cluster_size = batch_cluster_size - self.recurrent_model = RSSM(latent_dim, rssm_dim, actions_num) - # NOTE: In Dreamer paper VQ-VAE has MLP after conv2d to get 1d embedding, - # which is concatenated with deterministic state and only after that - # sampled into discrete one-hot encoding (using TensorFlow.Distribution OneHotCategorical) - # self.representation_network = VQ_VAE( - # latent_dim=latent_dim, - # latent_space_size=latent_classes) # actually only 'encoder' part of VAE + # kl loss balancing (prior/posterior) + self.alpha = 0.8 + + self.recurrent_model = RSSM(latent_dim, + rssm_dim, + actions_num, + categories_num=latent_classes) self.encoder = Encoder() + from torchsummary import summary + + # summary(self.encoder, input_size=(3, 64, 64)) self.image_predictor = Decoder() - # self.image_predictor = 'decoder' part of VAE - # FIXME: will not work until VQ-VAE internal 
embedding will not be changed from 2d to 1d # FIXME: in Dreamer paper it is 4 hidden layers with 400 hidden units # FIXME: in Dramer paper it has Layer Normalization after Dense - self.transition_network = fc_nn_generator(rssm_dim, - latent_dim, - hidden_size=128, - num_layers=3) - self.reward_predictor = fc_nn_generator(rssm_dim + latent_dim, + self.reward_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, hidden_size=128, num_layers=3) - self.discount_predictor = fc_nn_generator(rssm_dim + latent_dim, + self.discount_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, hidden_size=128, - num_layers=3) + num_layers=3, + final_activation=nn.Sigmoid) - self.optimizer = torch.optim.Adam(self.representation_network.parameters(), + self.optimizer = torch.optim.Adam(itertools.chain( + self.recurrent_model.parameters(), self.encoder.parameters(), + self.image_predictor.parameters(), self.reward_predictor.parameters(), + self.discount_predictor.parameters()), lr=2e-4) def forward(self, X): @@ -130,45 +245,50 @@ def forward(self, X): def train(self, s: torch.Tensor, a: torch.Tensor, r: torch.Tensor, is_finished: torch.Tensor): b, h, w, _ = s.shape # s <- BxHxWx3 - s = s.view(-1, self.cluster_size, h, w, 3) + + s = ((s.type(torch.float32) / 255.0) - 0.5).permute(0, 3, 1, 2) + embed = self.encoder(s) + embed = embed.view(b // self.cluster_size, self.cluster_size, -1) + + s = s.view(-1, self.cluster_size, 3, h, w) a = a.view(-1, self.cluster_size, a.shape[1]) r = r.view(-1, self.cluster_size, 1) f = is_finished.view(-1, self.cluster_size, 1) - h_prev = torch.zeros((b, self.rssm_dim)) + h_prev = [ + torch.zeros((1, b // self.cluster_size, self.rssm_dim)), + torch.zeros((1, b // self.cluster_size, self.latent_dim*self.latent_classes)) + ] losses = defaultdict(lambda: torch.zeros(1)) - embed = self.encoder(s) + def KL(dist1, dist2): + KL_ = torch.distributions.kl_divergence + Dist = torch.distributions.OneHotCategoricalStraightThrough + return self.kl_beta *(self.alpha * KL_(dist1, Dist(logits=dist2.logits.detach())).mean() + + (1 - self.alpha) * KL_(Dist(logits=dist1.logits.detach()), dist2).mean()) + for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 - embed_t, a_t, r_t, f_t = embed[:, t].unsqueeze(0), a[:, t].unsqueeze( - 0), r[:, t].unsqueeze(0), f[:, t].unsqueeze(0) + x_t, embed_t, a_t, r_t, f_t = s[:, t], embed[:, t].unsqueeze( + 0), a[:, t].unsqueeze(0), r[:, t], f[:, t] + + determ_t, prior_stoch_dist, posterior_stoch_dist = self.recurrent_model( + h_prev, embed_t, a_t) + posterior_stoch = posterior_stoch_dist.rsample().reshape(-1, self.latent_dim*self.latent_classes) - # TODO: add in the future h_t into representation network - # NOTE: can be moved out of the loop, *embed* is calculated solely by image - # s_t_r, z_t, e_t = self.representation_network(s_t) - h_t = self.recurrent_model(h_prev, z_t, a_t) + r_t_pred = self.reward_predictor(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) + f_t_pred = self.discount_predictor(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) - r_t_pred = self.reward_predictor(torch.concat([h_t, z_t])) - f_t_pred = self.discount_predictor(torch.concat([h_t, z_t])) - z_t_prior = self.transition_network(h_t) + x_r = self.image_predictor(posterior_stoch) - vae_losses = self.representation_network.calculate_loss(s_t, s_t_r, z_t, e_t) - # NOTE: regularization loss from VQ-VAE is not used in Dreamer paper - losses['loss_reconstruction'] = vae_losses['loss_rec'] + losses['loss_reconstruction'] = 
nn.functional.mse_loss(x_t, x_r) losses['loss_reward_pred'] += F.mse_loss(r_t, r_t_pred) - losses['loss_discount_pred'] += F.cross_entropy(f_t, f_t_pred) - # TODO: add KL divergence loss between transition predictor and representation model - # NOTE: remember about different learning rate for prior and posterior - # NOTE: VQ-VAE should be changed to output the softmax of how close z is to each e, - # so it can be used to count as probability for each distribution to calculate - # the KL divergence - # NOTE: DreamerV2 uses TensorFlow.Probability to calculate KL divergence - losses['loss_kl_reg'] += self.kl_beta * 0 - - h_prev = h_t - - loss = torch.Tensor(0) + losses['loss_discount_pred'] += F.cross_entropy(f_t.type(torch.float32), f_t_pred) + losses['loss_kl_reg'] += KL(prior_stoch_dist, posterior_stoch_dist) + + h_prev = [determ_t, posterior_stoch.unsqueeze(0)] + + loss = torch.Tensor(1) for l in losses.values(): loss += l @@ -204,11 +324,13 @@ def get_action(self, obs: State) -> Action: return self.actions_num def from_np(self, arr: np.ndarray): - return torch.from_numpy(arr).to(self.device_type) + return torch.from_numpy(arr).to(next(self.world_model.parameters()).device) def train(self, s: Observations, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): # NOTE: next is currently incorrect (state instead of img), but also unused + # FIXME: next should be correct, as World model is trained on triplets + # (h_prev, action, next_state) s = self.from_np(s) a = self.from_np(a) @@ -216,4 +338,4 @@ def train(self, s: Observations, a: Actions, r: Rewards, next: States, next = self.from_np(next) is_finished = self.from_np(is_finished) - self.world_model.train(s, a, r, is_finished) + return self.world_model.train(s, a, r, is_finished) From ae20fcb04911beda70a19d3d993622bc3373f79f Mon Sep 17 00:00:00 2001 From: Midren Date: Thu, 10 Nov 2022 14:26:53 +0000 Subject: [PATCH 009/106] Added generation of imaginative rollouts --- config/agent/dreamer_v2.yaml | 7 +- main.py | 14 +- rl_sandbox/agents/dreamer_v2.py | 181 ++++++++++++++++++------- rl_sandbox/utils/replay_buffer.py | 7 +- rl_sandbox/utils/rollout_generation.py | 34 ++++- 5 files changed, 180 insertions(+), 63 deletions(-) diff --git a/config/agent/dreamer_v2.yaml b/config/agent/dreamer_v2.yaml index 8c6086f..ed44c63 100644 --- a/config/agent/dreamer_v2.yaml +++ b/config/agent/dreamer_v2.yaml @@ -1,7 +1,12 @@ _target_: rl_sandbox.agents.DreamerV2 -discount_factor: 0.995 +# World model parameters batch_cluster_size: 8 latent_dim: 32 latent_classes: 32 rssm_dim: 200 kl_loss_scale: 0.1 + +# ActorCritic parameters +imagination_horizon: 15 +discount_factor: 0.995 +#critic_update_interval: 100 diff --git a/main.py b/main.py index 48562d2..1750e44 100644 --- a/main.py +++ b/main.py @@ -35,7 +35,7 @@ def main(cfg: DictConfig): buff = ReplayBuffer() obs_res = cfg.env.obs_res if cfg.env.run_on_pixels else None - fillup_replay_buffer(env, buff, cfg.training.batch_size, obs_res=obs_res) + fillup_replay_buffer(env, buff, cfg.training.batch_size, obs_res=obs_res, run_on_obs=cfg.env.run_on_pixels) action_disritizer = ActionDiscritizer(env.action_spec(), values_per_dim=10) metrics_evaluator = MetricsEvaluator() @@ -81,17 +81,19 @@ def main(cfg: DictConfig): match cfg.env.type: case "dm_control": new_state, reward, terminated = decode_dm_ts(env.step(action_disritizer.undiscretize(action))) - # FIXME: if run_on_pixels next_state should also be observation - obs = env.physics.render(*cfg.env.obs_res, camera_id=0) if cfg.env.run_on_pixels else 
None + new_obs = env.physics.render(*cfg.env.obs_res, camera_id=0) if cfg.env.run_on_pixels else None case "gym": new_state, reward, terminated, _, _ = env.step(action) action = action_disritizer.undiscretize(action) obs = None - buff.add_sample(state, action, reward, new_state, terminated, obs) + if cfg.env.run_on_pixels: + buff.add_sample(obs, action, reward, new_obs, terminated, obs) + else: + buff.add_sample(state, action, reward, new_state, terminated, obs) - # FIXME: unintuitive that batch_size is now number of total - # samples, but not amount of sequences for recurrent model + # NOTE: unintuitive that batch_size is now number of total + # samples, but not amount of sequences for recurrent model s, a, r, n, f = buff.sample(cfg.training.batch_size, return_observation=cfg.env.run_on_pixels, cluster_size=cfg.agent.get('batch_cluster_size', 1)) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index bf6d6fe..f27e52a 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -1,4 +1,5 @@ import itertools +import typing as t from collections import defaultdict import numpy as np @@ -9,8 +10,7 @@ from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.fc_nn import fc_nn_generator from rl_sandbox.utils.replay_buffer import (Action, Actions, Observations, - Rewards, State, States, - TerminationFlags) + Rewards, State, TerminationFlags) class View(nn.Module): @@ -73,10 +73,12 @@ def __init__(self, latent_dim, hidden_size, actions_num, categories_num): self.latent_dim = latent_dim self.categories_num = categories_num self.ensemble_num = 5 + self.hidden_size = hidden_size # Calculate deterministic state from prev stochastic, prev action and prev deterministic self.pre_determ_recurrent = nn.Sequential( - nn.Linear(latent_dim * categories_num + actions_num, hidden_size), # Dreamer 'img_in' + nn.Linear(latent_dim * categories_num + actions_num, + hidden_size), # Dreamer 'img_in' nn.LayerNorm(hidden_size), ) self.determ_recurrent = nn.GRU(input_size=hidden_size, @@ -119,6 +121,20 @@ def estimate_stochastic_latent(self, prev_determ): return torch.distributions.one_hot_categorical.OneHotCategoricalStraightThrough( logits=logits_per_model[index]) + def predict_next(self, + stoch_latent, + action, + deter_state: t.Optional[torch.Tensor] = None): + if deter_state is None: + deter_state = torch.zeros(*stoch_latent.shape[:2], self.hidden_size).to( + next(self.stoch_net.parameters()).device) + x = self.pre_determ_recurrent(torch.concat([stoch_latent, action], dim=2)) + _, determ = self.determ_recurrent(x, deter_state) + + # used for KL divergence + predicted_stoch_latent = self.estimate_stochastic_latent(determ) + return deter_state, predicted_stoch_latent + def forward(self, h_prev: tuple[torch.Tensor, torch.Tensor], embed, action): """ 'h' <- internal state of the world @@ -129,11 +145,9 @@ def forward(self, h_prev: tuple[torch.Tensor, torch.Tensor], embed, action): # Use zero vector for prev_state of first deter_prev, stoch_prev = h_prev - x = self.pre_determ_recurrent(torch.concat([stoch_prev, action], dim=2)) - _, determ = self.determ_recurrent(x, deter_prev) - - # used for KL divergence - prior_stoch_dist = self.estimate_stochastic_latent(determ) + determ, prior_stoch_dist = self.predict_next(stoch_prev, + action, + deter_state=deter_prev) posterior_stoch_logits = self.stoch_net(torch.concat([determ, embed], dim=2)) # Dreamer 'obs_out' @@ -143,9 +157,6 @@ def forward(self, h_prev: tuple[torch.Tensor, torch.Tensor], embed, action): 
return [determ, prior_stoch_dist, posterior_stoch_dist] -# NOTE: In Dreamer ELU is used everywhere as activation func -# NOTE: In Dreamer 48**(lvl) filter size is used, 4 level of convolution, -# Layer Normalizatin instead of Batch # NOTE: residual blocks are not used inside dreamer class Encoder(nn.Module): @@ -176,18 +187,18 @@ def __init__(self, kernel_sizes=[5, 5, 6, 6]): layers = [] self.channel_step = 48 # 2**(len(kernel_sizes)-1)*channel_step - self.convin = nn.Linear(32*32, 32*self.channel_step) + self.convin = nn.Linear(32 * 32, 32 * self.channel_step) - in_channels = 32*self.channel_step #2**(len(kernel_sizes) - 1) * self.channel_step + in_channels = 32 * self.channel_step #2**(len(kernel_sizes) - 1) * self.channel_step for i, k in enumerate(kernel_sizes): out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step if i == len(kernel_sizes) - 1: out_channels = 3 - layers.append( - nn.ConvTranspose2d(in_channels, 3, kernel_size=k, stride=2)) + layers.append(nn.ConvTranspose2d(in_channels, 3, kernel_size=k, stride=2)) else: layers.append( - nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=2)) + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, + stride=2)) layers.append(nn.ELU(inplace=True)) layers.append(nn.BatchNorm2d(out_channels)) in_channels = out_channels @@ -195,7 +206,7 @@ def __init__(self, kernel_sizes=[5, 5, 6, 6]): def forward(self, X): x = self.convin(X) - x = x.view(-1, 32*self.channel_step, 1, 1) + x = x.view(-1, 32 * self.channel_step, 1, 1) return self.net(x) @@ -217,20 +228,15 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, categories_num=latent_classes) self.encoder = Encoder() - from torchsummary import summary - - # summary(self.encoder, input_size=(3, 64, 64)) self.image_predictor = Decoder() - # FIXME: in Dreamer paper it is 4 hidden layers with 400 hidden units - # FIXME: in Dramer paper it has Layer Normalization after Dense self.reward_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, - hidden_size=128, - num_layers=3) + hidden_size=400, + num_layers=4) self.discount_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, - hidden_size=128, - num_layers=3, + hidden_size=400, + num_layers=4, final_activation=nn.Sigmoid) self.optimizer = torch.optim.Adam(itertools.chain( @@ -239,54 +245,69 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.discount_predictor.parameters()), lr=2e-4) - def forward(self, X): - pass + def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor]): + determ_state, next_repr_dist = self.recurrent_model.predict_next( + latent_repr.unsqueeze(0), action.unsqueeze(0), world_state) + + next_repr = next_repr_dist.sample().reshape(-1, self.latent_dim * self.latent_classes) + reward = self.reward_predictor(torch.concat([determ_state.squeeze(0), next_repr], dim=1)) + is_finished = self.discount_predictor( torch.concat([determ_state.squeeze(0), next_repr], dim=1)) + return determ_state, next_repr, reward, is_finished - def train(self, s: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + def train(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, is_finished: torch.Tensor): - b, h, w, _ = s.shape # s <- BxHxWx3 + b, h, w, _ = obs.shape # s <- BxHxWx3 - s = ((s.type(torch.float32) / 255.0) - 0.5).permute(0, 3, 1, 2) - embed = self.encoder(s) + obs = ((obs.type(torch.float32) / 255.0) - 0.5).permute(0, 3, 1, 2) + embed = self.encoder(obs) embed = embed.view(b // 
self.cluster_size, self.cluster_size, -1) - s = s.view(-1, self.cluster_size, 3, h, w) + obs = obs.view(-1, self.cluster_size, 3, h, w) a = a.view(-1, self.cluster_size, a.shape[1]) r = r.view(-1, self.cluster_size, 1) f = is_finished.view(-1, self.cluster_size, 1) h_prev = [ torch.zeros((1, b // self.cluster_size, self.rssm_dim)), - torch.zeros((1, b // self.cluster_size, self.latent_dim*self.latent_classes)) + torch.zeros( + (1, b // self.cluster_size, self.latent_dim * self.latent_classes)) ] losses = defaultdict(lambda: torch.zeros(1)) def KL(dist1, dist2): KL_ = torch.distributions.kl_divergence Dist = torch.distributions.OneHotCategoricalStraightThrough - return self.kl_beta *(self.alpha * KL_(dist1, Dist(logits=dist2.logits.detach())).mean() + - (1 - self.alpha) * KL_(Dist(logits=dist1.logits.detach()), dist2).mean()) + return self.kl_beta * ( + self.alpha * KL_(dist1, Dist(logits=dist2.logits.detach())).mean() + + (1 - self.alpha) * KL_(Dist(logits=dist1.logits.detach()), dist2).mean()) + + latent_vars = [] for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 - x_t, embed_t, a_t, r_t, f_t = s[:, t], embed[:, t].unsqueeze( + x_t, embed_t, a_t, r_t, f_t = obs[:, t], embed[:, t].unsqueeze( 0), a[:, t].unsqueeze(0), r[:, t], f[:, t] determ_t, prior_stoch_dist, posterior_stoch_dist = self.recurrent_model( h_prev, embed_t, a_t) - posterior_stoch = posterior_stoch_dist.rsample().reshape(-1, self.latent_dim*self.latent_classes) + posterior_stoch = posterior_stoch_dist.rsample().reshape( + -1, self.latent_dim * self.latent_classes) - r_t_pred = self.reward_predictor(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) - f_t_pred = self.discount_predictor(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) + r_t_pred = self.reward_predictor( + torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) + f_t_pred = self.discount_predictor( + torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) x_r = self.image_predictor(posterior_stoch) losses['loss_reconstruction'] = nn.functional.mse_loss(x_t, x_r) losses['loss_reward_pred'] += F.mse_loss(r_t, r_t_pred) - losses['loss_discount_pred'] += F.cross_entropy(f_t.type(torch.float32), f_t_pred) + losses['loss_discount_pred'] += F.cross_entropy(f_t.type(torch.float32), + f_t_pred) losses['loss_kl_reg'] += KL(prior_stoch_dist, posterior_stoch_dist) h_prev = [determ_t, posterior_stoch.unsqueeze(0)] + latent_vars.append(posterior_stoch.detach()) loss = torch.Tensor(1) for l in losses.values(): @@ -296,7 +317,48 @@ def KL(dist1, dist2): loss.backward() self.optimizer.step() - return {l: val.detach() for l, val in losses.items()} + return {l: val.detach() for l, val in losses.items()}, torch.stack(latent_vars) + + +class ImaginativeActorCritic(nn.Module): + + def __init__(self): + super().__init__() + # mixing of reinforce and maximizing value func + self.rho = 0 # for dm_control it is zero in Dreamer (Atari 1) + # scale for entropy + self.eta = 1e-4 + # parameter for n-step return for value function + self.lambda_par = 0.95 + + # NOTE: stochastic + # MLP with ELU activation + # Output layer categorical distribution + self.actor = ... + + # NOTE: deterministic + # MLP with ELU activation + # Uses Q-network and Actor to chose actions + self.critic = ... + self.target_actor = ... + + def get_action(self, z): + b, _ = z.shape + return torch.ones((b, 1)) + + def estimate_value(self, z) -> torch.Tensor: + ... + + # NOTE: use target network for calculating + # labmda target + def lambda_return(self): + ... 
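# --- Illustrative sketch (editor's addition, not part of this patch): the "lambda
# target" referenced in the NOTE above is the TD(lambda) return used by DreamerV2,
#   V^lambda_t = r_t + gamma_t * ((1 - lambda) * v(s_{t+1}) + lambda * V^lambda_{t+1}),
# computed backwards over an imagined trajectory and bootstrapped from the target
# critic. The helper below is a minimal standalone version; the argument names and
# shapes (rewards/discounts of shape [H, B, 1], values of shape [H+1, B, 1]) are
# assumptions, and it reuses the module-level `torch` import of this file.
def lambda_return_sketch(rewards, values, discounts, lambda_=0.95):
    # Bootstrap from the last target-critic value, then roll the recursion backwards.
    last = values[-1]
    returns = []
    for t in reversed(range(len(rewards))):
        last = rewards[t] + discounts[t] * ((1 - lambda_) * values[t + 1] + lambda_ * last)
        returns.append(last)
    # `returns` was filled from t = H-1 down to 0, so reverse before stacking.
    return torch.stack(returns[::-1])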
+ + def calculate_loss(self): + # Critic <- Bellman return ( (V - V_lambda)^2 ) + # Bellman return for critic + # Actor <- reinforce + maximize value_func + maximize actor entropy + ... class DreamerV2(RlAgent): @@ -310,15 +372,33 @@ def __init__(self, rssm_dim: int, discount_factor: float, kl_loss_scale: float, + imagination_horizon: int, device_type: str = 'cpu'): + self.imagination_horizon = imagination_horizon self.cluster_size = batch_cluster_size self.actions_num = actions_num + self.H = imagination_horizon self.gamma = discount_factor self.world_model = WorldModel(batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale).to(device_type) + self.actor_critic = ImaginativeActorCritic() + + def imagine_trajectory( + self, + z_0) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: + rollout = [] + world_state = None + z = z_0 + for _ in range(self.imagination_horizon): + a = self.actor_critic.get_action(z) + world_state, next_z, reward, is_finished = self.world_model.predict_next( + z, a, world_state) + rollout.append((z, a, next_z, reward, is_finished)) + z = next_z + return rollout def get_action(self, obs: State) -> Action: return self.actions_num @@ -326,11 +406,8 @@ def get_action(self, obs: State) -> Action: def from_np(self, arr: np.ndarray): return torch.from_numpy(arr).to(next(self.world_model.parameters()).device) - def train(self, s: Observations, a: Actions, r: Rewards, next: States, + def train(self, s: Observations, a: Actions, r: Rewards, next: Observations, is_finished: TerminationFlags): - # NOTE: next is currently incorrect (state instead of img), but also unused - # FIXME: next should be correct, as World model is trained on triplets - # (h_prev, action, next_state) s = self.from_np(s) a = self.from_np(a) @@ -338,4 +415,14 @@ def train(self, s: Observations, a: Actions, r: Rewards, next: States, next = self.from_np(next) is_finished = self.from_np(is_finished) - return self.world_model.train(s, a, r, is_finished) + # take some latent embeddings as initial step + losses, discovered_latents = self.world_model.train(next, a, r, is_finished) + + perm = torch.randperm(discovered_latents.size(0)) + idx = perm[:12] + initial_states = discovered_latents[idx] + + for z_0 in initial_states: + rollout = self.imagine_trajectory(z_0) + + return losses diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index 26a516e..f195904 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -1,16 +1,15 @@ import typing as t -from collections import deque from dataclasses import dataclass import numpy as np from nptyping import Bool, Float, Int, NDArray, Shape Observation = NDArray[Shape["*,*,3"], Int] -State = NDArray[Shape["*"], Float] +State = NDArray[Shape["*"], Float] | Observation Action = NDArray[Shape["*"], Int] Observations = NDArray[Shape["*,*,*,3"], Int] -States = NDArray[Shape["*,*"], Float] +States = NDArray[Shape["*,*"], Float] | Observations Actions = NDArray[Shape["*,*"], Int] Rewards = NDArray[Shape["*"], Float] TerminationFlags = NDArray[Shape["*"], Bool] @@ -64,7 +63,7 @@ def add_rollout(self, rollout: Rollout): self.observations = self.observations[:self.max_len] def add_sample(self, s: State, a: Action, r: float, n: State, f: bool, - o: t.Optional[Observation]): + o: t.Optional[Observation] = None): rollout = Rollout(np.array([s]), np.expand_dims(np.array([a]), 0), np.array([r], dtype=np.float32), np.array([n]), np.array([f]), np.array([o]) if o is not None else 
None) diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index 2aed382..e7ac46b 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -9,14 +9,28 @@ from rl_sandbox.utils.replay_buffer import ReplayBuffer, Rollout -def collect_rollout(env: gym.Env | dmEnv, agent: t.Optional[t.Any] = None, obs_res: t.Optional[t.Tuple[int, int]] = None) -> Rollout: +# FIXME: whole function duplicates a lot of code from main.py +def collect_rollout(env: gym.Env | dmEnv, + agent: t.Optional[t.Any] = None, + obs_res: t.Optional[t.Tuple[int, int]] = None, + run_on_obs: bool = False, + ) -> Rollout: + if run_on_obs and obs_res is None: + raise RuntimeError("Run on pixels cannot be done without specified resolution") + s, a, r, n, f, o = [], [], [], [], [], [] + # TODO: worth creating N+1 standard of env, which will incorporate + # gym/dm_control/etc to remove just bloated matching all over project match env: case gym.Env(): state, _ = env.reset() + if run_on_obs is True: + raise RuntimeError("Run on pixels currently supported only for dm_control") case dmEnv(): state, _, terminated = decode_dm_ts(env.reset()) + if run_on_obs is True: + state = env.physics.render(*obs_res, camera_id=0) if agent is None: agent = RandomAgent(env) @@ -30,6 +44,8 @@ def collect_rollout(env: gym.Env | dmEnv, agent: t.Optional[t.Any] = None, obs_r new_state, reward, terminated, _, _ = env.step(action) case dmEnv(): new_state, reward, terminated = decode_dm_ts(env.step(action)) + if run_on_obs: + new_state = env.physics.render(*obs_res, camera_id=0) s.append(state) # FIXME: action discritezer should be defined once @@ -50,15 +66,23 @@ def collect_rollout(env: gym.Env | dmEnv, agent: t.Optional[t.Any] = None, obs_r obs = np.array(o) if obs_res is not None else None return Rollout(np.array(s), np.array(a).reshape(len(s), -1), np.array(r, dtype=np.float32), np.array(n), np.array(f), obs) -def collect_rollout_num(env: gym.Env, num: int, agent: t.Optional[t.Any] = None, obs_res: bool = False) -> t.List[Rollout]: +def collect_rollout_num(env: gym.Env, + num: int, + agent: t.Optional[t.Any] = None, + obs_res: t.Optional[t.Tuple[int, int]] = None, + run_on_obs: bool = False) -> t.List[Rollout]: # TODO: paralelyze rollouts = [] for _ in range(num): - rollouts.append(collect_rollout(env, agent, obs_res)) + rollouts.append(collect_rollout(env, agent, obs_res, run_on_obs)) return rollouts -def fillup_replay_buffer(env: gym.Env, rep_buffer: ReplayBuffer, num: int, obs_res: t.Optional[t.Tuple[int, int]] = None): +def fillup_replay_buffer(env: gym.Env, + rep_buffer: ReplayBuffer, + num: int, + obs_res: t.Optional[t.Tuple[int, int]] = None, + run_on_obs: bool = False): # TODO: paralelyze while not rep_buffer.can_sample(num): - rep_buffer.add_rollout(collect_rollout(env, obs_res=obs_res)) + rep_buffer.add_rollout(collect_rollout(env, obs_res=obs_res, run_on_obs=run_on_obs)) From 4c3f24910b03299fdb356435944cbc63c6a48750 Mon Sep 17 00:00:00 2001 From: Midren Date: Thu, 10 Nov 2022 19:30:12 +0000 Subject: [PATCH 010/106] Implemented ActorCritic training in Latent Space --- config/agent/dreamer_v2.yaml | 12 +- config/config.yaml | 1 + main.py | 13 +- rl_sandbox/agents/dreamer_v2.py | 302 ++++++++++++++++++++------------ rl_sandbox/agents/rl_agent.py | 5 + rl_sandbox/utils/fc_nn.py | 3 +- 6 files changed, 219 insertions(+), 117 deletions(-) diff --git a/config/agent/dreamer_v2.yaml b/config/agent/dreamer_v2.yaml index ed44c63..7682452 100644 --- 
a/config/agent/dreamer_v2.yaml +++ b/config/agent/dreamer_v2.yaml @@ -7,6 +7,14 @@ rssm_dim: 200 kl_loss_scale: 0.1 # ActorCritic parameters -imagination_horizon: 15 discount_factor: 0.995 -#critic_update_interval: 100 +imagination_horizon: 15 +# Lambda parameter for trainin deeper multi-step prediction +critic_value_target_lambda: 0.95 +critic_update_interval: 100 +# [0-1], 1 means hard update +critic_soft_update_fraction: 1 +# mixing of reinforce and maximizing value func +# for dm_control it is zero in Dreamer (Atari 1) +actor_reinforce_fraction: 0 +actor_entropy_scale: 1e-4 diff --git a/config/config.yaml b/config/config.yaml index fb38c36..f9cbc1a 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -10,6 +10,7 @@ env: task_name: swingup run_on_pixels: true obs_res: [64, 64] + action_discrete_num: 10 seed: 42 device_type: cpu diff --git a/main.py b/main.py index 1750e44..570e268 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,6 @@ import numpy as np from dm_control import suite from omegaconf import DictConfig, OmegaConf -import numpy as np from torch.utils.tensorboard.writer import SummaryWriter from tqdm import tqdm @@ -33,11 +32,13 @@ def main(cfg: DictConfig): case _: raise RuntimeError("Invalid environment type") + # TODO: As images take much more data, rewrite replay buffer to be + # more memory efficient buff = ReplayBuffer() obs_res = cfg.env.obs_res if cfg.env.run_on_pixels else None fillup_replay_buffer(env, buff, cfg.training.batch_size, obs_res=obs_res, run_on_obs=cfg.env.run_on_pixels) - action_disritizer = ActionDiscritizer(env.action_spec(), values_per_dim=10) + action_disritizer = ActionDiscritizer(env.action_spec(), values_per_dim=cfg.env.action_discrete_num) metrics_evaluator = MetricsEvaluator() match cfg.env.type: @@ -49,10 +50,9 @@ def main(cfg: DictConfig): obs_space_num = env.observation_space.shape[0] exploration_agent = RandomAgent(env) - # FIXME: currently action is 1 value, but not one-hot encoding agent = hydra.utils.instantiate(cfg.agent, obs_space_num=obs_space_num, - actions_num=(1), + actions_num=(cfg.env.action_discrete_num), device_type=cfg.device_type) writer = SummaryWriter() @@ -69,14 +69,16 @@ def main(cfg: DictConfig): obs = env.physics.render(*cfg.env.obs_res, camera_id=0) if cfg.env.run_on_pixels else None case "gym": state, _ = env.reset() + agent.reset() terminated = False while not terminated: + # TODO: For dreamer, add noise for sampling if np.random.random() > scheduler.step(): action = exploration_agent.get_action(state) action = action_disritizer.discretize(action) else: - action = agent.get_action(state) + action = agent.get_action(obs if cfg.env.run_on_pixels else state) match cfg.env.type: case "dm_control": @@ -98,6 +100,7 @@ def main(cfg: DictConfig): return_observation=cfg.env.run_on_pixels, cluster_size=cfg.agent.get('batch_cluster_size', 1)) + # NOTE: Dreamer makes 4 policy steps per gradient descent losses = agent.train(s, a, r, n, f) if isinstance(losses, np.ndarray): writer.add_scalar('train/loss', loss, global_step) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index f27e52a..5204979 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -9,8 +9,9 @@ from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.fc_nn import fc_nn_generator -from rl_sandbox.utils.replay_buffer import (Action, Actions, Observations, - Rewards, State, TerminationFlags) +from rl_sandbox.utils.replay_buffer import (Action, Actions, Observation, + Observations, Rewards, + 
TerminationFlags) class View(nn.Module): @@ -42,7 +43,7 @@ class Quantize(nn.Module): def forward(self, logits): dist = torch.distributions.one_hot_categorical.OneHotCategoricalStraightThrough( logits=logits) - return dist.rsample() + return dist class RSSM(nn.Module): @@ -55,12 +56,12 @@ class RSSM(nn.Module): s_t <- stohastic discrete posterior state (latent representation of current state) h_1 ---> h_2 ---> h_3 ---> - \ x_1 \ x_2 \ x_3 - | \ | ^ | \ | ^ | \ | ^ + \\ x_1 \\ x_2 \\ x_3 + | \\ | ^ | \\ | ^ | \\ | ^ v MLP CNN | v MLP CNN | v MLP CNN | - \ | | \ | | \ | | - Ensemble \ | | Ensemble \ | | Ensemble \ | | - \| | \| | \| | + \\ | | \\ | | \\ | | + Ensemble \\ | | Ensemble \\ | | Ensemble \\ | | + \\| | \\| | \\| | | | | | | | | | | v v | v v | v v | | | | @@ -68,16 +69,16 @@ class RSSM(nn.Module): """ - def __init__(self, latent_dim, hidden_size, actions_num, categories_num): + def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): super().__init__() self.latent_dim = latent_dim - self.categories_num = categories_num + self.latent_classes = latent_classes self.ensemble_num = 5 self.hidden_size = hidden_size # Calculate deterministic state from prev stochastic, prev action and prev deterministic self.pre_determ_recurrent = nn.Sequential( - nn.Linear(latent_dim * categories_num + actions_num, + nn.Linear(latent_dim * latent_classes + actions_num, hidden_size), # Dreamer 'img_in' nn.LayerNorm(hidden_size), ) @@ -91,13 +92,13 @@ def __init__(self, latent_dim, hidden_size, actions_num, categories_num): nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' nn.LayerNorm(hidden_size), nn.Linear(hidden_size, - latent_dim * self.categories_num), # Dreamer 'img_dist_{k}' - View((-1, latent_dim, self.categories_num))) - for _ in range(self.ensemble_num) + latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' + View((-1, latent_dim, self.latent_classes)), + Quantize()) for _ in range(self.ensemble_num) ]) # For observation we do not have ensemble - # FIXME: very band magic number + # FIXME: very bad magic number img_sz = 4 * 384 # 384*2x2 self.stoch_net = nn.Sequential( nn.Linear(hidden_size + img_sz, hidden_size), @@ -105,21 +106,18 @@ def __init__(self, latent_dim, hidden_size, actions_num, categories_num): nn.Linear(hidden_size, hidden_size), # Dreamer 'obs_out' nn.LayerNorm(hidden_size), nn.Linear(hidden_size, - latent_dim * self.categories_num), # Dreamer 'obs_dist' - View((-1, latent_dim, self.categories_num)), + latent_dim * self.latent_classes), # Dreamer 'obs_dist' + View((-1, latent_dim, self.latent_classes)), # NOTE: Maybe worth having some LogSoftMax as activation # before using input as logits for distribution - # Quantize() - ) + Quantize()) def estimate_stochastic_latent(self, prev_determ): - logits_per_model = torch.stack( - [model(prev_determ) for model in self.ensemble_prior_estimator]) + dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] # NOTE: Maybe something smarter can be used instead of # taking only one random between all ensembles - index = torch.randint(0, self.ensemble_num, ()) - return torch.distributions.one_hot_categorical.OneHotCategoricalStraightThrough( - logits=logits_per_model[index]) + idx = torch.randint(0, self.ensemble_num, ()) + return dists_per_model[idx] def predict_next(self, stoch_latent, @@ -135,7 +133,11 @@ def predict_next(self, predicted_stoch_latent = self.estimate_stochastic_latent(determ) return deter_state, predicted_stoch_latent - def forward(self, h_prev: tuple[torch.Tensor, 
torch.Tensor], embed, action): + def update_current(self, determ, embed): # Dreamer 'obs_out' + return self.stoch_net(torch.concat([determ, embed], dim=2)) + + def forward(self, h_prev: t.Optional[tuple[torch.Tensor, torch.Tensor]], embed, + action): """ 'h' <- internal state of the world 'z' <- latent embedding of current observation @@ -144,15 +146,15 @@ def forward(self, h_prev: tuple[torch.Tensor, torch.Tensor], embed, action): """ # Use zero vector for prev_state of first + if h_prev is None: + h_prev = (torch.zeros((*action.shape[:-1], self.hidden_size)), + torch.zeros( + (*action.shape[:-1], self.latent_dim * self.latent_classes))) deter_prev, stoch_prev = h_prev determ, prior_stoch_dist = self.predict_next(stoch_prev, action, deter_state=deter_prev) - - posterior_stoch_logits = self.stoch_net(torch.concat([determ, embed], - dim=2)) # Dreamer 'obs_out' - posterior_stoch_dist = torch.distributions.one_hot_categorical.OneHotCategoricalStraightThrough( - logits=posterior_stoch_logits) + posterior_stoch_dist = self.update_current(determ, embed) return [determ, prior_stoch_dist, posterior_stoch_dist] @@ -226,17 +228,19 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.recurrent_model = RSSM(latent_dim, rssm_dim, actions_num, - categories_num=latent_classes) + latent_classes=latent_classes) self.encoder = Encoder() self.image_predictor = Decoder() self.reward_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, hidden_size=400, - num_layers=4) + num_layers=4, + intermediate_activation=nn.ELU) self.discount_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, hidden_size=400, num_layers=4, + intermediate_activation=nn.ELU, final_activation=nn.Sigmoid) self.optimizer = torch.optim.Adam(itertools.chain( @@ -249,16 +253,25 @@ def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor determ_state, next_repr_dist = self.recurrent_model.predict_next( latent_repr.unsqueeze(0), action.unsqueeze(0), world_state) - next_repr = next_repr_dist.sample().reshape(-1, self.latent_dim * self.latent_classes) - reward = self.reward_predictor(torch.concat([determ_state.squeeze(0), next_repr], dim=1)) - is_finished = self.discount_predictor( torch.concat([determ_state.squeeze(0), next_repr], dim=1)) + next_repr = next_repr_dist.rsample().reshape( + -1, self.latent_dim * self.latent_classes) + reward = self.reward_predictor( + torch.concat([determ_state.squeeze(0), next_repr], dim=1)) + is_finished = self.discount_predictor( + torch.concat([determ_state.squeeze(0), next_repr], dim=1)) return determ_state, next_repr, reward, is_finished + def get_latent(self, obs: torch.Tensor, state): + embed = self.encoder(obs) + determ, _, latent_repr_dist = self.recurrent_model(state, embed.unsqueeze(0), + self._last_action) + latent_repr = latent_repr_dist.rsample().reshape(-1, 32 * 32) + return determ, latent_repr.unsqueeze(0) + def train(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, is_finished: torch.Tensor): b, h, w, _ = obs.shape # s <- BxHxWx3 - obs = ((obs.type(torch.float32) / 255.0) - 0.5).permute(0, 3, 1, 2) embed = self.encoder(obs) embed = embed.view(b // self.cluster_size, self.cluster_size, -1) @@ -267,12 +280,8 @@ def train(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, r = r.view(-1, self.cluster_size, 1) f = is_finished.view(-1, self.cluster_size, 1) - h_prev = [ - torch.zeros((1, b // self.cluster_size, self.rssm_dim)), - torch.zeros( - (1, b // self.cluster_size, self.latent_dim * 
self.latent_classes)) - ] - losses = defaultdict(lambda: torch.zeros(1)) + h_prev = None + losses = defaultdict(lambda: torch.zeros(1).to(next(self.parameters()).device)) def KL(dist1, dist2): KL_ = torch.distributions.kl_divergence @@ -317,112 +326,187 @@ def KL(dist1, dist2): loss.backward() self.optimizer.step() - return {l: val.detach() for l, val in losses.items()}, torch.stack(latent_vars) + discovered_latents = torch.stack(latent_vars).reshape( + -1, self.latent_dim * self.latent_classes) + return {l: val.detach() for l, val in losses.items()}, discovered_latents -class ImaginativeActorCritic(nn.Module): +class ImaginativeCritic(nn.Module): - def __init__(self): + def __init__(self, discount_factor: float, update_interval: int, + soft_update_fraction: float, value_target_lambda: float, latent_dim: int, + actions_num: int): super().__init__() - # mixing of reinforce and maximizing value func - self.rho = 0 # for dm_control it is zero in Dreamer (Atari 1) - # scale for entropy - self.eta = 1e-4 - # parameter for n-step return for value function - self.lambda_par = 0.95 - - # NOTE: stochastic - # MLP with ELU activation - # Output layer categorical distribution - self.actor = ... - - # NOTE: deterministic - # MLP with ELU activation - # Uses Q-network and Actor to chose actions - self.critic = ... - self.target_actor = ... - - def get_action(self, z): - b, _ = z.shape - return torch.ones((b, 1)) + self.gamma = discount_factor + self.critic_update_interval = update_interval + self.lambda_ = value_target_lambda + self.critic_soft_update_fraction = soft_update_fraction + self._update_num = 0 + + self.critic = fc_nn_generator(latent_dim, + actions_num, + 400, + 1, + intermediate_activation=nn.ELU) + self.target_critic = fc_nn_generator(latent_dim, + actions_num, + 400, + 1, + intermediate_activation=nn.ELU) + + def update_target(self): + if self._update_num == 0: + for target_param, local_param in zip(self.target_critic.parameters(), + self.critic.parameters()): + mix = self.critic_soft_update_fraction + target_param.data.copy_(mix * local_param.data + + (1 - mix) * target_param.data) + self._update_num = (self._update_num + 1) % self.critic_update_interval def estimate_value(self, z) -> torch.Tensor: - ... - - # NOTE: use target network for calculating - # labmda target - def lambda_return(self): - ... + return self.critic(z) - def calculate_loss(self): - # Critic <- Bellman return ( (V - V_lambda)^2 ) - # Bellman return for critic - # Actor <- reinforce + maximize value_func + maximize actor entropy - ... 
+ def lambda_return(self, zs, rs, ts): + v_lambdas = [self.target_critic(zs[-1])] + for r, z, t in zip(reversed(rs[:-1]), reversed(zs[:-1]), reversed(ts[:-1])): + v_lambda = r + t * self.gamma * ( + (1 - self.lambda_) * self.target_critic(z) + self.lambda_ * v_lambdas[-1]) + v_lambdas.append(v_lambda) + return torch.concat(list(reversed(v_lambdas)), dim=0) class DreamerV2(RlAgent): - def __init__(self, - obs_space_num: int, - actions_num: int, - batch_cluster_size: int, - latent_dim: int, - latent_classes: int, - rssm_dim: int, - discount_factor: float, - kl_loss_scale: float, - imagination_horizon: int, - device_type: str = 'cpu'): - + def __init__( + self, + obs_space_num: int, # NOTE: encoder/decoder will work only with 64x64 currently + actions_num: int, + batch_cluster_size: int, + latent_dim: int, + latent_classes: int, + rssm_dim: int, + discount_factor: float, + kl_loss_scale: float, + imagination_horizon: int, + critic_update_interval: int, + actor_reinforce_fraction: float, + actor_entropy_scale: float, + critic_soft_update_fraction: float, + critic_value_target_lambda: float, + device_type: str = 'cpu'): + + self._state = None + self._last_action = torch.zeros(actions_num) + self.actions_num = actions_num self.imagination_horizon = imagination_horizon self.cluster_size = batch_cluster_size self.actions_num = actions_num - self.H = imagination_horizon - self.gamma = discount_factor + self.rho = actor_reinforce_fraction + if actor_reinforce_fraction != 0: + raise NotImplementedError("Reinforce part is not implemented") + self.eta = actor_entropy_scale self.world_model = WorldModel(batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale).to(device_type) - self.actor_critic = ImaginativeActorCritic() + self.actor = fc_nn_generator(latent_dim, + actions_num, + 400, + 4, + intermediate_activation=nn.ELU, + final_activation=Quantize) + # TODO: Leave only ImaginativeCritic and move Actor to DreamerV2 + self.critic = ImaginativeCritic(discount_factor, critic_update_interval, + critic_soft_update_fraction, + critic_value_target_lambda, + latent_dim * latent_classes, actions_num) + + self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=4e-5) + self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-4) def imagine_trajectory( - self, - z_0) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: + self, z_0 + ) -> list[tuple[torch.Tensor, torch.distributions.Distribution, torch.Tensor, + torch.Tensor]]: rollout = [] world_state = None - z = z_0 + z = z_0.detach().unsqueeze(0) for _ in range(self.imagination_horizon): - a = self.actor_critic.get_action(z) + a = self.actor(z) world_state, next_z, reward, is_finished = self.world_model.predict_next( - z, a, world_state) - rollout.append((z, a, next_z, reward, is_finished)) - z = next_z + z, a.rsample(), world_state) + rollout.append( + (z.detach(), a, next_z.detach(), reward.detach(), is_finished.detach())) + z = next_z.detach() return rollout - def get_action(self, obs: State) -> Action: - return self.actions_num + def reset(self): + self._state = None + self._last_action = torch.zeros((1, 1, self.actions_num)) + + def preprocess_obs(self, obs: torch.Tensor): + order = list(range(obs.shape)) + # Swap channel from last to 3 from last + order = order[:-3] + [order[-1]] + [order[-3:-1]] + return ((obs.type(torch.float32) / 255.0) - 0.5).permute(order) + + def get_action(self, obs: Observation) -> Action: + # NOTE: pytorch fails without .copy() only when get_action is 
called + obs = torch.from_numpy(obs.copy()).to(next(self.world_model.parameters()).device) + obs = self.preprocess_obs(obs) + + self._state = self.world_model.get_latent(obs, self._state) + self._last_action = self.actor(self._state[1]).rsample().unsqueeze(0) + + return self._last_action.squeeze().detach().cpu().numpy().argmax() def from_np(self, arr: np.ndarray): return torch.from_numpy(arr).to(next(self.world_model.parameters()).device) - def train(self, s: Observations, a: Actions, r: Rewards, next: Observations, + def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observations, is_finished: TerminationFlags): - s = self.from_np(s) + obs = self.preprocess_obs(self.from_np(obs)) a = self.from_np(a) + a = F.one_hot(a, num_classes=self.actions_num).squeeze() r = self.from_np(r) - next = self.from_np(next) + next_obs = self.from_np(next_obs) is_finished = self.from_np(is_finished) # take some latent embeddings as initial step - losses, discovered_latents = self.world_model.train(next, a, r, is_finished) + losses, discovered_latents = self.world_model.train(next_obs, a, r, is_finished) - perm = torch.randperm(discovered_latents.size(0)) - idx = perm[:12] + idx = torch.randperm(discovered_latents.size(0)) initial_states = discovered_latents[idx] + losses_ac = defaultdict( + lambda: torch.zeros(1).to(next(self.critic.parameters()).device)) + for z_0 in initial_states: rollout = self.imagine_trajectory(z_0) - - return losses + zs, action_dists, next_zs, rewards, terminal_flags = zip(*rollout) + vs = self.critic.lambda_return(next_zs, rewards, terminal_flags) + + losses_ac['loss_critic'] += F.mse_loss(self.critic.estimate_value( + torch.stack(next_zs).squeeze(1)), + vs.detach(), + reduction='sum') + + losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control + losses_ac['loss_actor_dynamics_backprop'] += (-(1 - self.rho) * vs[-1]).mean() + losses_ac['loss_actor_entropy'] += -self.eta * torch.stack( + [a.entropy() for a in action_dists[:-1]]).mean() + losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ + 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] + + self.actor_optimizer.zero_grad() + self.critic_optimizer.zero_grad() + losses_ac['loss_critic'].backward() + losses_ac['loss_actor'].backward() + self.actor_optimizer.step() + self.critic_optimizer.step() + self.critic.update_target() + + losses_ac = {l: val.detach() for l, val in losses_ac.items()} + + return losses | losses_ac diff --git a/rl_sandbox/agents/rl_agent.py b/rl_sandbox/agents/rl_agent.py index 97c4fa7..b895191 100644 --- a/rl_sandbox/agents/rl_agent.py +++ b/rl_sandbox/agents/rl_agent.py @@ -10,3 +10,8 @@ def get_action(self, obs: State) -> Action: @abstractmethod def train(self, s: States, a: Actions, r: Rewards, next: States): pass + + # Some models can have internal state which should be + # properly reseted between rollouts + def reset(self): + pass diff --git a/rl_sandbox/utils/fc_nn.py b/rl_sandbox/utils/fc_nn.py index a66d082..f00edb8 100644 --- a/rl_sandbox/utils/fc_nn.py +++ b/rl_sandbox/utils/fc_nn.py @@ -5,13 +5,14 @@ def fc_nn_generator(input_num: int, output_num: int, hidden_size: int, num_layers: int, + intermediate_activation: t.Type[nn.Module] = nn.ReLU, final_activation: t.Type[nn.Module] = nn.Identity): layers = [] layers.append(nn.Linear(input_num, hidden_size)) layers.append(nn.ReLU(inplace=True)) for _ in range(num_layers): layers.append(nn.Linear(hidden_size, hidden_size)) - layers.append(nn.ReLU(inplace=True)) + 
layers.append(intermediate_activation(inplace=True)) layers.append(nn.Linear(hidden_size, output_num)) layers.append(final_activation()) return nn.Sequential(*layers) From a8089b95688ab8edded6d315ecfd6cbea69b74b3 Mon Sep 17 00:00:00 2001 From: Midren Date: Mon, 14 Nov 2022 17:33:46 +0000 Subject: [PATCH 011/106] Added Env abstraction for different env support, and action transform --- config/config.yaml | 11 +- config/env/dm_cartpole.yaml | 9 ++ main.py | 73 ++------- pyproject.toml | 1 + rl_sandbox/agents/dqn.py | 2 +- rl_sandbox/agents/dreamer_v2.py | 29 ++-- rl_sandbox/agents/random_agent.py | 20 +-- rl_sandbox/agents/rl_agent.py | 6 +- rl_sandbox/utils/dm_control.py | 11 -- rl_sandbox/utils/env.py | 196 +++++++++++++++++++++++++ rl_sandbox/utils/replay_buffer.py | 6 +- rl_sandbox/utils/rollout_generation.py | 68 +++------ 12 files changed, 269 insertions(+), 163 deletions(-) create mode 100644 config/env/dm_cartpole.yaml create mode 100644 rl_sandbox/utils/env.py diff --git a/config/config.yaml b/config/config.yaml index f9cbc1a..3aa535f 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,17 +1,8 @@ defaults: - agent/dreamer_v2 + - env/dm_cartpole - _self_ -env: - # type: gym - # name: CartPole-v1 - type: dm_control - domain_name: cartpole - task_name: swingup - run_on_pixels: true - obs_res: [64, 64] - action_discrete_num: 10 - seed: 42 device_type: cpu diff --git a/config/env/dm_cartpole.yaml b/config/env/dm_cartpole.yaml new file mode 100644 index 0000000..14a1a59 --- /dev/null +++ b/config/env/dm_cartpole.yaml @@ -0,0 +1,9 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: cartpole +task_name: swingup +run_on_pixels: true +obs_res: [64, 64] +transforms: + - _target_: rl_sandbox.utils.env.ActionNormalizer + - _target_: rl_sandbox.utils.env.ActionDisritezer + actions_num: 10 diff --git a/main.py b/main.py index 570e268..6dbbec5 100644 --- a/main.py +++ b/main.py @@ -1,14 +1,14 @@ -import gym import hydra import numpy as np -from dm_control import suite from omegaconf import DictConfig, OmegaConf from torch.utils.tensorboard.writer import SummaryWriter from tqdm import tqdm +from unpackable import unpack from rl_sandbox.agents.random_agent import RandomAgent from rl_sandbox.metrics import MetricsEvaluator -from rl_sandbox.utils.dm_control import ActionDiscritizer, decode_dm_ts +from rl_sandbox.utils.dm_control import ActionDiscritizer +from rl_sandbox.utils.env import Env from rl_sandbox.utils.replay_buffer import ReplayBuffer from rl_sandbox.utils.rollout_generation import (collect_rollout_num, fillup_replay_buffer) @@ -19,40 +19,19 @@ def main(cfg: DictConfig): # print(OmegaConf.to_yaml(cfg)) - match cfg.env.type: - case "dm_control": - env = suite.load(domain_name=cfg.env.domain_name, - task_name=cfg.env.task_name) - visualized_env = env - case "gym": - env = gym.make(cfg.env) - visualized_env = gym.make(cfg.env, render_mode='rgb_array_list') - if cfg.env.run_on_pixels: - raise NotImplementedError("Run on pixels supported only for 'dm_control'") - case _: - raise RuntimeError("Invalid environment type") + env: Env = hydra.utils.instantiate(cfg.env) # TODO: As images take much more data, rewrite replay buffer to be # more memory efficient buff = ReplayBuffer() - obs_res = cfg.env.obs_res if cfg.env.run_on_pixels else None - fillup_replay_buffer(env, buff, cfg.training.batch_size, obs_res=obs_res, run_on_obs=cfg.env.run_on_pixels) + fillup_replay_buffer(env, buff, cfg.training.batch_size) - action_disritizer = ActionDiscritizer(env.action_spec(), 
values_per_dim=cfg.env.action_discrete_num) metrics_evaluator = MetricsEvaluator() - match cfg.env.type: - case "dm_control": - obs_space_num = sum([v.shape[0] for v in env.observation_spec().values()]) - if cfg.env.run_on_pixels: - obs_space_num = (*cfg.env.obs_res, 3) - case "gym": - obs_space_num = env.observation_space.shape[0] - exploration_agent = RandomAgent(env) agent = hydra.utils.instantiate(cfg.agent, - obs_space_num=obs_space_num, - actions_num=(cfg.env.action_discrete_num), + obs_space_num=env.observation_space.shape[0], + actions_num=env.action_space.shape[0], device_type=cfg.device_type) writer = SummaryWriter() @@ -63,12 +42,7 @@ def main(cfg: DictConfig): for epoch_num in tqdm(range(cfg.training.epochs)): ### Training and exploration - match cfg.env.type: - case "dm_control": - state, _, _ = decode_dm_ts(env.reset()) - obs = env.physics.render(*cfg.env.obs_res, camera_id=0) if cfg.env.run_on_pixels else None - case "gym": - state, _ = env.reset() + state, _, _ = unpack(env.reset()) agent.reset() terminated = False @@ -76,39 +50,22 @@ def main(cfg: DictConfig): # TODO: For dreamer, add noise for sampling if np.random.random() > scheduler.step(): action = exploration_agent.get_action(state) - action = action_disritizer.discretize(action) - else: - action = agent.get_action(obs if cfg.env.run_on_pixels else state) - - match cfg.env.type: - case "dm_control": - new_state, reward, terminated = decode_dm_ts(env.step(action_disritizer.undiscretize(action))) - new_obs = env.physics.render(*cfg.env.obs_res, camera_id=0) if cfg.env.run_on_pixels else None - case "gym": - new_state, reward, terminated, _, _ = env.step(action) - action = action_disritizer.undiscretize(action) - obs = None - - if cfg.env.run_on_pixels: - buff.add_sample(obs, action, reward, new_obs, terminated, obs) else: - buff.add_sample(state, action, reward, new_state, terminated, obs) + action = agent.get_action(state) + + new_state, reward, terminated = unpack(env.step(action)) + + buff.add_sample(state, action, reward, new_state, terminated) # NOTE: unintuitive that batch_size is now number of total # samples, but not amount of sequences for recurrent model s, a, r, n, f = buff.sample(cfg.training.batch_size, - return_observation=cfg.env.run_on_pixels, cluster_size=cfg.agent.get('batch_cluster_size', 1)) # NOTE: Dreamer makes 4 policy steps per gradient descent losses = agent.train(s, a, r, n, f) - if isinstance(losses, np.ndarray): - writer.add_scalar('train/loss', loss, global_step) - elif isinstance(losses, dict): - for loss_name, loss in losses.items(): - writer.add_scalar(f'train/{loss_name}', loss, global_step) - else: - raise RuntimeError("AAAA, very bad") + for loss_name, loss in losses.items(): + writer.add_scalar(f'train/{loss_name}', loss, global_step) global_step += 1 ### Validation diff --git a/pyproject.toml b/pyproject.toml index 4484501..dbd22d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,3 +24,4 @@ torchvision = '^0.13' torch = '^1.12' tensorboard = '^2.0' dm-control = '^1.0.0' +unpackable = '^0.0.4' diff --git a/rl_sandbox/agents/dqn.py b/rl_sandbox/agents/dqn.py index 9c7b01f..b2fb3b0 100644 --- a/rl_sandbox/agents/dqn.py +++ b/rl_sandbox/agents/dqn.py @@ -49,4 +49,4 @@ def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: Te loss.backward() self.optimizer.step() - return loss.detach().cpu() + return {'loss': loss.detach().cpu()} diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 5204979..764f8de 100644 --- 
a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -147,9 +147,7 @@ def forward(self, h_prev: t.Optional[tuple[torch.Tensor, torch.Tensor]], embed, # Use zero vector for prev_state of first if h_prev is None: - h_prev = (torch.zeros((*action.shape[:-1], self.hidden_size)), - torch.zeros( - (*action.shape[:-1], self.latent_dim * self.latent_classes))) + h_prev = (torch.zeros((*action.shape[:-1], self.hidden_size)), torch.zeros((*action.shape[:-1], self.latent_dim * self.latent_classes))) deter_prev, stoch_prev = h_prev determ, prior_stoch_dist = self.predict_next(stoch_prev, action, @@ -261,22 +259,21 @@ def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor torch.concat([determ_state.squeeze(0), next_repr], dim=1)) return determ_state, next_repr, reward, is_finished - def get_latent(self, obs: torch.Tensor, state): + def get_latent(self, obs: torch.Tensor, action, state): embed = self.encoder(obs) - determ, _, latent_repr_dist = self.recurrent_model(state, embed.unsqueeze(0), - self._last_action) + determ, _, latent_repr_dist = self.recurrent_model(state, embed.unsqueeze(0), action) latent_repr = latent_repr_dist.rsample().reshape(-1, 32 * 32) return determ, latent_repr.unsqueeze(0) def train(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, is_finished: torch.Tensor): - b, h, w, _ = obs.shape # s <- BxHxWx3 + b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) embed = embed.view(b // self.cluster_size, self.cluster_size, -1) obs = obs.view(-1, self.cluster_size, 3, h, w) - a = a.view(-1, self.cluster_size, a.shape[1]) + a = a.view(-1, self.cluster_size, 1) r = r.view(-1, self.cluster_size, 1) f = is_finished.view(-1, self.cluster_size, 1) @@ -409,7 +406,7 @@ def __init__( self.world_model = WorldModel(batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale).to(device_type) - self.actor = fc_nn_generator(latent_dim, + self.actor = fc_nn_generator(latent_dim * latent_classes, actions_num, 400, 4, @@ -445,20 +442,20 @@ def reset(self): self._last_action = torch.zeros((1, 1, self.actions_num)) def preprocess_obs(self, obs: torch.Tensor): - order = list(range(obs.shape)) + order = list(range(len(obs.shape))) # Swap channel from last to 3 from last - order = order[:-3] + [order[-1]] + [order[-3:-1]] + order = order[:-3] + [order[-1]] + order[-3:-1] return ((obs.type(torch.float32) / 255.0) - 0.5).permute(order) def get_action(self, obs: Observation) -> Action: # NOTE: pytorch fails without .copy() only when get_action is called obs = torch.from_numpy(obs.copy()).to(next(self.world_model.parameters()).device) - obs = self.preprocess_obs(obs) + obs = self.preprocess_obs(obs).unsqueeze(0) - self._state = self.world_model.get_latent(obs, self._state) - self._last_action = self.actor(self._state[1]).rsample().unsqueeze(0) + self._state = self.world_model.get_latent(obs, self._last_action, self._state) + self._last_action = self.actor(self._state[1]).rsample() - return self._last_action.squeeze().detach().cpu().numpy().argmax() + return np.array([self._last_action.squeeze().detach().cpu().numpy().argmax()]) def from_np(self, arr: np.ndarray): return torch.from_numpy(arr).to(next(self.world_model.parameters()).device) @@ -470,7 +467,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation a = self.from_np(a) a = F.one_hot(a, num_classes=self.actions_num).squeeze() r = self.from_np(r) - next_obs = self.from_np(next_obs) + next_obs = self.preprocess_obs(self.from_np(next_obs)) 
is_finished = self.from_np(is_finished) # take some latent embeddings as initial step diff --git a/rl_sandbox/agents/random_agent.py b/rl_sandbox/agents/random_agent.py index d06e031..b1dffb4 100644 --- a/rl_sandbox/agents/random_agent.py +++ b/rl_sandbox/agents/random_agent.py @@ -1,30 +1,18 @@ -import gym import numpy as np -from dm_control.composer.environment import Environment as dmEnv from nptyping import Float, NDArray, Shape from rl_sandbox.agents.rl_agent import RlAgent -from rl_sandbox.utils.dm_control import ActionDiscritizer +from rl_sandbox.utils.env import Env from rl_sandbox.utils.replay_buffer import (Action, Actions, Rewards, State, States, TerminationFlags) class RandomAgent(RlAgent): - def __init__(self, env: gym.Env | dmEnv): - self.action_space = None - self.action_spec = None - if isinstance(env, gym.Env): - self.action_space = env.action_space - else: - self.action_spec = env.action_spec() + def __init__(self, env: Env): + self.action_space = env.action_space def get_action(self, obs: State) -> Action | NDArray[Shape["*"],Float]: - if self.action_space is not None: - return self.action_space.sample() - else: - return np.random.uniform(self.action_spec.minimum, - self.action_spec.maximum, - size=self.action_spec.shape) + return self.action_space.sample() def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): pass diff --git a/rl_sandbox/agents/rl_agent.py b/rl_sandbox/agents/rl_agent.py index b895191..b159db3 100644 --- a/rl_sandbox/agents/rl_agent.py +++ b/rl_sandbox/agents/rl_agent.py @@ -1,3 +1,4 @@ +from typing import Any from abc import ABCMeta, abstractmethod from rl_sandbox.utils.replay_buffer import Action, State, States, Actions, Rewards @@ -8,7 +9,10 @@ def get_action(self, obs: State) -> Action: pass @abstractmethod - def train(self, s: States, a: Actions, r: Rewards, next: States): + def train(self, s: States, a: Actions, r: Rewards, next: States) -> dict[str, Any]: + """ + Return dict with losses for logging + """ pass # Some models can have internal state which should be diff --git a/rl_sandbox/utils/dm_control.py b/rl_sandbox/utils/dm_control.py index f31e64d..92bc6f6 100644 --- a/rl_sandbox/utils/dm_control.py +++ b/rl_sandbox/utils/dm_control.py @@ -37,14 +37,3 @@ def undiscretize(self, action: NDArray[Shape['*'], Int]) -> NDArray[Shape['*'], for k, vals in zip(reversed(ks), self.grid): a.append(vals[k]) return np.array(a) - -def decode_dm_ts(time_step): - state = time_step.observation - state = np.concatenate([state[s] for s in state], dtype=np.float32) - reward = time_step.reward - terminated = time_step.last() - # if time_step.discount is not None: - # terminated = not time_step.discount - # else: - # terminated = False - return state, reward, terminated diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py new file mode 100644 index 0000000..200ecf9 --- /dev/null +++ b/rl_sandbox/utils/env.py @@ -0,0 +1,196 @@ +import typing as t +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass + +import gym +import numpy as np +from dm_control import suite +from dm_env import Environment as dmEnviron +from dm_env import TimeStep +from nptyping import Float, Int, NDArray, Shape + +Observation = NDArray[Shape["*,*,3"], Int] +State = NDArray[Shape["*"], Float] +Action = NDArray[Shape["*"], Int] + + +@dataclass +class EnvStepResult: + obs: Observation | State + reward: float + terminated: bool + + +class ActionTransformer(metaclass=ABCMeta): + + def set_env(self, env: 'Env'): + self.low = 
env.action_space.low + self.high = env.action_space.high + + @abstractmethod + def transform_action(self, action): + ... + + @abstractmethod + def transform_space(self, space: gym.spaces.Box): + ... + + +class ActionNormalizer(ActionTransformer): + + def set_env(self, env: 'Env'): + super().set_env(env) + if (~np.isfinite(self.low) | ~np.isfinite(self.high)).any(): + raise RuntimeError("Not bounded space cannot be normalized") + + def transform_action(self, action): + return (self.high - self.low) * (action + 1) / 2 + self.low + + def transform_space(self, space: gym.spaces.Box): + return gym.spaces.Box(-np.ones_like(self.low), + np.ones_like(self.high), + dtype=np.float32) + + +class ActionDisritezer(ActionTransformer): + + def __init__(self, actions_num: int): + self.per_dim = actions_num + + def set_env(self, env: 'Env'): + super().set_env(env) + if (~np.isfinite(self.low) | ~np.isfinite(self.high)).any(): + raise RuntimeError("Not bounded space cannot be discritized") + + self.grid = np.stack([ + np.linspace(min, max, self.per_dim, endpoint=True) + for min, max in zip(self.low, self.high) + ]) + + def transform_action(self, action: NDArray[Shape['*'], + Int]) -> NDArray[Shape['*'], Float]: + ks = [] + k = action + for i in range(self.per_dim - 1, -1, -1): + ks.append(k // self.per_dim**i) + k -= ks[-1] * self.per_dim**i + + a = [] + for k, vals in zip(reversed(ks), self.grid): + a.append(vals[k]) + return np.array(a) + + def transform_space(self, space: gym.spaces.Box): + return gym.spaces.Box(0, self.per_dim**len(self.low), dtype=np.uint32) + + +class Env(metaclass=ABCMeta): + + def __init__(self, run_on_pixels: bool, obs_res: tuple[int, int], + transforms: list[ActionTransformer]): + self.obs_res = obs_res + self.run_on_pixels = run_on_pixels + self.ac_trans = [] + for t in transforms: + t.set_env(self) + self.ac_trans.append(t) + + def step(self, action: Action) -> EnvStepResult: + for t in reversed(self.ac_trans): + action = t.transform_action(action) + return self._step(action) + + @abstractmethod + def _step(self, action: Action) -> EnvStepResult: + pass + + @abstractmethod + def reset(self) -> EnvStepResult: + pass + + @abstractmethod + def _observation_space(self) -> gym.Space: + pass + + @abstractmethod + def _action_space(self) -> gym.Space: + ... 
+ + @property + def observation_space(self) -> gym.Space: + return self._observation_space() + + @property + def action_space(self) -> gym.Space: + space = self._action_space() + for t in self.ac_trans: + space = t.transform_space(t) + return space + + +class GymEnv(Env): + + def __init__(self, task_name: str, run_on_pixels: bool, obs_res: tuple[int, int], + transforms: list[ActionTransformer]): + super().__init__(run_on_pixels, obs_res, transforms) + + self.env: gym.Env = gym.make(task_name) + self.visualized_env: gym.Env = gym.make(task_name, render_mode='rgb_array_list') + + if run_on_pixels: + raise NotImplementedError("Run on pixels supported only for 'dm_control'") + + def _step(self, action: Action) -> EnvStepResult: + new_state, reward, terminated, _, _ = self.env.step(action) + return EnvStepResult(new_state, reward, terminated) + + def reset(self): + state, _ = self.env.reset() + return EnvStepResult(state, 0, False) + + @property + def _observation_space(self): + return self.env.observation_space + + def _action_space(self): + return self.env.action_space + + +class DmEnv(Env): + + def __init__(self, run_on_pixels: bool, obs_res: tuple[int, int], domain_name: str, + task_name: str, + transforms: list[ActionTransformer]): + self.env: dmEnviron = suite.load(domain_name=domain_name, task_name=task_name) + super().__init__(run_on_pixels, obs_res, transforms) + + def render(self): + return self.env.physics.render(*self.obs_res, camera_id=0) + + def _uncode_ts(self, ts: TimeStep) -> EnvStepResult: + if self.run_on_pixels: + state = self.render() + else: + state = ts.observation + state = np.concatenate([state[s] for s in state], dtype=np.float32) + return EnvStepResult(state, ts.reward, ts.last()) + + def _step(self, action: Action) -> EnvStepResult: + # TODO: add action repeat to speed up DMC simulations + return self._uncode_ts(self.env.step(action)) + + def reset(self) -> EnvStepResult: + return self._uncode_ts(self.env.reset()) + + def _observation_space(self): + if self.run_on_pixels: + return gym.spaces.Box(0, 255, self.obs_res + (3, ), dtype=np.uint8) + else: + raise NotImplementedError( + "Currently run on pixels is supported for 'dm_control'") + # for space in self.env.observation_spec(): + # obs_space_num = sum([v.shape[0] for v in env.observation_space().values()]) + + def _action_space(self): + spec = self.env.action_spec() + return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index f195904..e18be12 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -64,7 +64,7 @@ def add_rollout(self, rollout: Rollout): def add_sample(self, s: State, a: Action, r: float, n: State, f: bool, o: t.Optional[Observation] = None): - rollout = Rollout(np.array([s]), np.expand_dims(np.array([a]), 0), + rollout = Rollout(np.array([s]), np.array([a]), np.array([r], dtype=np.float32), np.array([n]), np.array([f]), np.array([o]) if o is not None else None) self.add_rollout(rollout) @@ -75,13 +75,11 @@ def can_sample(self, num: int): def sample( self, batch_size: int, - return_observation: bool = False, cluster_size: int = 1 ) -> t.Tuple[States, Actions, Rewards, States, TerminationFlags]: # TODO: add warning if batch_size % cluster_size != 0 # FIXME: currently doesn't take into account discontinuations between between rollouts indeces = np.random.choice(len(self.states) - (cluster_size - 1), batch_size//cluster_size) indeces = np.stack([indeces + i for i in 
range(cluster_size)]).flatten(order='F') - o = self.states[indeces] if not return_observation else self.observations[indeces] - return o, self.actions[indeces], self.rewards[indeces], self.next_states[ + return self.states[indeces], self.actions[indeces], self.rewards[indeces], self.next_states[ indeces], self.is_finished[indeces] diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index e7ac46b..f6742fa 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -1,88 +1,64 @@ import typing as t -import gym import numpy as np -from dm_env import Environment as dmEnv +from unpackable import unpack from rl_sandbox.agents.random_agent import RandomAgent -from rl_sandbox.utils.dm_control import ActionDiscritizer, decode_dm_ts +from rl_sandbox.utils.env import Env from rl_sandbox.utils.replay_buffer import ReplayBuffer, Rollout # FIXME: whole function duplicates a lot of code from main.py -def collect_rollout(env: gym.Env | dmEnv, +def collect_rollout(env: Env, agent: t.Optional[t.Any] = None, - obs_res: t.Optional[t.Tuple[int, int]] = None, - run_on_obs: bool = False, + collect_obs: bool = False ) -> Rollout: - if run_on_obs and obs_res is None: - raise RuntimeError("Run on pixels cannot be done without specified resolution") s, a, r, n, f, o = [], [], [], [], [], [] - # TODO: worth creating N+1 standard of env, which will incorporate - # gym/dm_control/etc to remove just bloated matching all over project - match env: - case gym.Env(): - state, _ = env.reset() - if run_on_obs is True: - raise RuntimeError("Run on pixels currently supported only for dm_control") - case dmEnv(): - state, _, terminated = decode_dm_ts(env.reset()) - if run_on_obs is True: - state = env.physics.render(*obs_res, camera_id=0) + state, _, terminated = unpack(env.reset()) if agent is None: agent = RandomAgent(env) - while not terminated: action = agent.get_action(state) - match env: - case gym.Env(): - new_state, reward, terminated, _, _ = env.step(action) - case dmEnv(): - new_state, reward, terminated = decode_dm_ts(env.step(action)) - if run_on_obs: - new_state = env.physics.render(*obs_res, camera_id=0) + new_state, reward, terminated = unpack(env.step(action)) s.append(state) - # FIXME: action discritezer should be defined once - action_disritizer = ActionDiscritizer(env.action_spec(), values_per_dim=10) - a.append(action_disritizer.discretize(action)) + a.append(action) r.append(reward) n.append(new_state) f.append(terminated) - if obs_res is not None and isinstance(env, dmEnv): - o.append(env.physics.render(*obs_res, camera_id=0)) + # FIXME: obs are not collected yet + # if collect_obs and isinstance(env, dmEnv): + # o.append(env.render()) state = new_state - match env: - case gym.Env(): - obs = np.stack(list(env.render())) if obs_res is not None else None - case dmEnv(): - obs = np.array(o) if obs_res is not None else None + obs = None + # match env: + # case gym.Env(): + # obs = np.stack(list(env.render())) if obs_res is not None else None + # case dmEnv(): + # obs = np.array(o) if obs_res is not None else None return Rollout(np.array(s), np.array(a).reshape(len(s), -1), np.array(r, dtype=np.float32), np.array(n), np.array(f), obs) -def collect_rollout_num(env: gym.Env, +def collect_rollout_num(env: Env, num: int, agent: t.Optional[t.Any] = None, - obs_res: t.Optional[t.Tuple[int, int]] = None, - run_on_obs: bool = False) -> t.List[Rollout]: + collect_obs: bool = False) -> t.List[Rollout]: # TODO: paralelyze rollouts = [] for _ 
in range(num): - rollouts.append(collect_rollout(env, agent, obs_res, run_on_obs)) + rollouts.append(collect_rollout(env, agent, obs_res)) return rollouts -def fillup_replay_buffer(env: gym.Env, +def fillup_replay_buffer(env: Env, rep_buffer: ReplayBuffer, - num: int, - obs_res: t.Optional[t.Tuple[int, int]] = None, - run_on_obs: bool = False): + num: int): # TODO: paralelyze while not rep_buffer.can_sample(num): - rep_buffer.add_rollout(collect_rollout(env, obs_res=obs_res, run_on_obs=run_on_obs)) + rep_buffer.add_rollout(collect_rollout(env, collect_obs=False)) From cc038c1ac37bba13865608598d0a6447300ec4f2 Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 15 Nov 2022 11:12:09 +0000 Subject: [PATCH 012/106] Rewritten replay buffer to more memory efficient and fixed cluster sampling --- config/config.yaml | 1 + main.py | 22 +++--- rl_sandbox/utils/replay_buffer.py | 108 ++++++++++++++++++------------ tests/test_replay_buffer.py | 48 ++++++------- 4 files changed, 100 insertions(+), 79 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 3aa535f..f2264d0 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -9,6 +9,7 @@ device_type: cpu training: epochs: 5000 batch_size: 128 + gradient_steps_per_step: 4 validation: rollout_num: 5 diff --git a/main.py b/main.py index 6dbbec5..cacaaaf 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,6 @@ from rl_sandbox.agents.random_agent import RandomAgent from rl_sandbox.metrics import MetricsEvaluator -from rl_sandbox.utils.dm_control import ActionDiscritizer from rl_sandbox.utils.env import Env from rl_sandbox.utils.replay_buffer import ReplayBuffer from rl_sandbox.utils.rollout_generation import (collect_rollout_num, @@ -21,8 +20,6 @@ def main(cfg: DictConfig): env: Env = hydra.utils.instantiate(cfg.env) - # TODO: As images take much more data, rewrite replay buffer to be - # more memory efficient buff = ReplayBuffer() fillup_replay_buffer(env, buff, cfg.training.batch_size) @@ -47,22 +44,23 @@ def main(cfg: DictConfig): terminated = False while not terminated: - # TODO: For dreamer, add noise for sampling - if np.random.random() > scheduler.step(): - action = exploration_agent.get_action(state) - else: - action = agent.get_action(state) + if global_step % cfg.training.gradient_steps_per_step == 0: + # TODO: For dreamer, add noise for sampling + if np.random.random() > scheduler.step(): + action = exploration_agent.get_action(state) + else: + action = agent.get_action(state) - new_state, reward, terminated = unpack(env.step(action)) + new_state, reward, terminated = unpack(env.step(action)) - buff.add_sample(state, action, reward, new_state, terminated) + buff.add_sample(state, action, reward, new_state, terminated) # NOTE: unintuitive that batch_size is now number of total # samples, but not amount of sequences for recurrent model s, a, r, n, f = buff.sample(cfg.training.batch_size, cluster_size=cfg.agent.get('batch_cluster_size', 1)) - # NOTE: Dreamer makes 4 policy steps per gradient descent + # TODO: add checkpoint saver for model losses = agent.train(s, a, r, n, f) for loss_name, loss in losses.items(): writer.add_scalar(f'train/{loss_name}', loss, global_step) @@ -76,7 +74,7 @@ def main(cfg: DictConfig): writer.add_scalar(f'val/{metric_name}', metric, epoch_num) if cfg.validation.visualize: - rollouts = collect_rollout_num(visualized_env, 1, agent, obs_res=cfg.obs_res) + rollouts = collect_rollout_num(env, 1, agent, collect_obs=True) for rollout in rollouts: video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) 
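[Editor's note] The rewritten buffer below stores whole rollouts and samples fixed-length clusters from them, weighting each rollout by its length. As a rough standalone sketch of that sampling scheme (illustrative names such as `toy_rollouts` and `sample_clusters` are assumptions, not part of this patch):

    import numpy as np

    # Toy rollouts: each one is a dict of equally long per-step arrays.
    toy_rollouts = [
        {"states": np.arange(10), "rewards": np.arange(10, dtype=np.float32)},
        {"states": np.arange(6), "rewards": np.arange(6, dtype=np.float32)},
    ]

    def sample_clusters(rollouts, batch_size, cluster_size):
        lens = np.array([len(r["rewards"]) for r in rollouts])
        seq_num = batch_size // cluster_size
        # Longer rollouts are proportionally more likely to be picked.
        picked = np.random.choice(len(rollouts), seq_num, p=lens / lens.sum())
        states, rewards = [], []
        for i in picked:
            # Choose a start index so the whole window stays inside the rollout.
            start = np.random.randint(0, lens[i] - cluster_size + 1)
            states.append(rollouts[i]["states"][start:start + cluster_size])
            rewards.append(rollouts[i]["rewards"][start:start + cluster_size])
        return np.concatenate(states), np.concatenate(rewards)

    s, r = sample_clusters(toy_rollouts, batch_size=8, cluster_size=4)

The windows stay contiguous within a single rollout, which is what lets the recurrent world model in later patches treat each cluster as an unbroken sequence.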
diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index e18be12..109853a 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -1,4 +1,5 @@ import typing as t +from collections import deque from dataclasses import dataclass import numpy as np @@ -28,58 +29,77 @@ class Rollout: class ReplayBuffer: def __init__(self, max_len=2_000): + self.rollouts: deque[Rollout] = deque() + self.rollouts_len: deque[int] = deque() + self.curr_rollout = None self.max_len = max_len - self.states: States = np.array([]) - self.actions: Actions = np.array([]) - self.rewards: Rewards = np.array([]) - self.next_states: States = np.array([]) - self.observations: t.Optional[Observations] + self.total_num = 0 + + def __len__(self): + return self.total_num def add_rollout(self, rollout: Rollout): - if len(self.states) == 0: - self.states = rollout.states - self.actions = rollout.actions - self.rewards = rollout.rewards - self.next_states = rollout.next_states - self.is_finished = rollout.is_finished - self.observations = rollout.observations + # NOTE: only last next state is stored, all others are induced + # from state on next step + rollout.next_states = np.expand_dims(rollout.next_states[-1], 0) + self.rollouts.append(rollout) + self.total_num += len(self.rollouts[-1].rewards) + self.rollouts_len.append(len(self.rollouts[-1].rewards)) + + while self.total_num >= self.max_len: + self.total_num -= self.rollouts_len[0] + self.rollouts_len.popleft() + self.rollouts.popleft() + + # Add sample expects that each subsequent sample + # will be continuation of last rollout util termination flag true + # is encountered + def add_sample(self, s: State, a: Action, r: float, n: State, f: bool): + if self.curr_rollout is None: + self.curr_rollout = Rollout([s], [a], [r], None, [f]) else: - self.states = np.concatenate([self.states, rollout.states]) - self.actions = np.concatenate([self.actions, rollout.actions]) - self.rewards = np.concatenate([self.rewards, rollout.rewards]) - self.next_states = np.concatenate([self.next_states, rollout.next_states]) - self.is_finished = np.concatenate([self.is_finished, rollout.is_finished]) - if self.observations is not None: - self.observations = np.concatenate( - [self.observations, rollout.observations]) - - if len(self.states) >= self.max_len: - self.states = self.states[:self.max_len] - self.actions = self.actions[:self.max_len] - self.rewards = self.rewards[:self.max_len] - self.next_states = self.next_states[:self.max_len] - self.is_finished = self.is_finished[:self.max_len] - if self.observations is not None: - self.observations = self.observations[:self.max_len] - - def add_sample(self, s: State, a: Action, r: float, n: State, f: bool, - o: t.Optional[Observation] = None): - rollout = Rollout(np.array([s]), np.array([a]), - np.array([r], dtype=np.float32), np.array([n]), np.array([f]), - np.array([o]) if o is not None else None) - self.add_rollout(rollout) + self.curr_rollout.states.append(s) + self.curr_rollout.actions.append(a) + self.curr_rollout.rewards.append(r) + self.curr_rollout.is_finished.append(f) + + if f: + self.curr_rollout = None + self.add_rollout( + Rollout(np.array(self.curr_rollout.states), + np.array(self.curr_rollout.actions), + np.array(self.curr_rollout.rewards, dtype=np.float32), + np.array([n]), np.array(self.curr_rollout.is_finished))) def can_sample(self, num: int): - return len(self.states) >= num + return self.total_num >= num def sample( self, batch_size: int, cluster_size: int = 1 - ) 
-> t.Tuple[States, Actions, Rewards, States, TerminationFlags]: - # TODO: add warning if batch_size % cluster_size != 0 - # FIXME: currently doesn't take into account discontinuations between between rollouts - indeces = np.random.choice(len(self.states) - (cluster_size - 1), batch_size//cluster_size) - indeces = np.stack([indeces + i for i in range(cluster_size)]).flatten(order='F') - return self.states[indeces], self.actions[indeces], self.rewards[indeces], self.next_states[ - indeces], self.is_finished[indeces] + ) -> tuple[States, Actions, Rewards, States, TerminationFlags]: + seq_num = batch_size // cluster_size + # NOTE: constant creation of numpy arrays from self.rollout_len seems terrible for me + s, a, r, n, t = [], [], [], [], [] + r_indeces = np.random.choice(len(self.rollouts), + seq_num, + p=np.array(self.rollouts_len) / self.total_num) + for r_idx in r_indeces: + # NOTE: maybe just no add such small rollouts to buffer + assert self.rollouts_len[r_idx] - cluster_size + 1 > 0, "Rollout it too small" + s_idx = np.random.choice(self.rollouts_len[r_idx] - cluster_size + 1, 1).item() + + s.append(self.rollouts[r_idx].states[s_idx:s_idx + cluster_size]) + a.append(self.rollouts[r_idx].actions[s_idx:s_idx + cluster_size]) + r.append(self.rollouts[r_idx].rewards[s_idx:s_idx + cluster_size]) + t.append(self.rollouts[r_idx].is_finished[s_idx:s_idx + cluster_size]) + if s_idx != self.rollouts_len[r_idx] - cluster_size: + n.append(self.rollouts[r_idx].states[s_idx+1:s_idx+1 + cluster_size]) + else: + if cluster_size != 1: + n.append(self.rollouts[r_idx].states[s_idx+1:s_idx+1 + cluster_size - 1]) + n.append(self.rollouts[r_idx].next_states) + + return (np.concatenate(s), np.concatenate(a), np.concatenate(r), + np.concatenate(n), np.concatenate(t)) diff --git a/tests/test_replay_buffer.py b/tests/test_replay_buffer.py index a08ab56..392db93 100644 --- a/tests/test_replay_buffer.py +++ b/tests/test_replay_buffer.py @@ -1,5 +1,3 @@ -import random - import numpy as np from pytest import fixture @@ -12,7 +10,7 @@ def rep_buf(): def test_creation(rep_buf: ReplayBuffer): - assert len(rep_buf.states) == 0 + assert len(rep_buf) == 0 def test_adding(rep_buf: ReplayBuffer): @@ -23,9 +21,7 @@ def test_adding(rep_buf: ReplayBuffer): f = np.zeros((3), dtype=np.bool8) rep_buf.add_rollout(Rollout(s, a, r, n, f)) - assert len(rep_buf.states) == 3 - assert len(rep_buf.actions) == 3 - assert len(rep_buf.rewards) == 3 + assert len(rep_buf) == 3 s = np.zeros((3, 8)) a = np.zeros((3, 3), dtype=np.int32) @@ -34,9 +30,7 @@ def test_adding(rep_buf: ReplayBuffer): f = np.zeros((3), dtype=np.bool8) rep_buf.add_rollout(Rollout(s, a, r, n, f)) - assert len(rep_buf.states) == 6 - assert len(rep_buf.actions) == 6 - assert len(rep_buf.rewards) == 6 + assert len(rep_buf) == 6 def test_can_sample(rep_buf: ReplayBuffer): @@ -63,23 +57,31 @@ def test_sampling(rep_buf: ReplayBuffer): Rollout(np.ones((1, 3)), np.ones((1, 2), dtype=np.int32), i * np.ones((1)), np.ones((3, 8)), np.zeros((3), dtype=np.bool8))) - random.seed(42) + np.random.seed(42) _, _, r, _, _ = rep_buf.sample(3) - assert (r == [1, 0, 3]).all() + assert (r == [1, 4, 3]).all() def test_cluster_sampling(rep_buf: ReplayBuffer): for i in range(5): rep_buf.add_rollout( - Rollout(np.ones((1, 3)), np.ones((1, 2), dtype=np.int32), i * np.ones((1)), - np.ones((3, 8)), np.zeros((3), dtype=np.bool8))) - - random.seed(42) - _, _, r, _, _ = rep_buf.sample(4, cluster_size=2) - assert (r == [1, 2, 3, 4]).all() - - _, _, r, _, _ = rep_buf.sample(4, cluster_size=2) - assert (r 
== [0, 1, 1, 2]).all() - - _, _, r, _, _ = rep_buf.sample(4, cluster_size=2) - assert (r == [2, 3, 2, 3]).all() + Rollout(np.stack([np.arange(3, dtype=np.float32) for _ in range(3)]).T, + np.ones((3, 2), dtype=np.int32), i * np.ones((3)), + np.stack([np.arange(1, 4, dtype=np.float32) for _ in range(3)]).T, + np.zeros((3), dtype=np.bool8))) + + np.random.seed(42) + s, _, r, n, _ = rep_buf.sample(4, cluster_size=2) + assert (r == [1, 1, 4, 4]).all() + assert (s[:, 0] == [0, 1, 1, 2]).all() + assert (n[:, 0] == [1, 2, 2, 3]).all() + + s, _, r, n, _ = rep_buf.sample(4, cluster_size=2) + assert (r == [2, 2, 0, 0]).all() + assert (s[:, 0] == [0, 1, 0, 1]).all() + assert (n[:, 0] == [1, 2, 1, 2]).all() + + s, _, r, n, _ = rep_buf.sample(4, cluster_size=2) + assert (r == [0, 0, 4, 4]).all() + assert (s[:, 0] == [1, 2, 1, 2]).all() + assert (n[:, 0] == [2, 3, 2, 3]).all() From 2266ea533f300c53bbeaeae9308566ffc8f02f95 Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 15 Nov 2022 13:25:20 +0000 Subject: [PATCH 013/106] Small improvements & fixes --- config/agent/dreamer_v2.yaml | 2 +- config/config.yaml | 4 +- config/env/dm_cartpole.yaml | 1 + main.py | 16 ++++- rl_sandbox/agents/dreamer_v2.py | 90 ++++++++++++++++++++++---- rl_sandbox/agents/random_agent.py | 3 + rl_sandbox/agents/rl_agent.py | 4 ++ rl_sandbox/utils/env.py | 23 ++++--- rl_sandbox/utils/replay_buffer.py | 3 +- rl_sandbox/utils/rollout_generation.py | 9 +-- 10 files changed, 122 insertions(+), 33 deletions(-) diff --git a/config/agent/dreamer_v2.yaml b/config/agent/dreamer_v2.yaml index 7682452..86abe78 100644 --- a/config/agent/dreamer_v2.yaml +++ b/config/agent/dreamer_v2.yaml @@ -1,6 +1,6 @@ _target_: rl_sandbox.agents.DreamerV2 # World model parameters -batch_cluster_size: 8 +batch_cluster_size: 16 latent_dim: 32 latent_classes: 32 rssm_dim: 200 diff --git a/config/config.yaml b/config/config.yaml index f2264d0..0f5a117 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -8,8 +8,10 @@ device_type: cpu training: epochs: 5000 - batch_size: 128 + batch_size: 64 gradient_steps_per_step: 4 + save_checkpoint_every: 1000 + val_logs_every: 5 validation: rollout_num: 5 diff --git a/config/env/dm_cartpole.yaml b/config/env/dm_cartpole.yaml index 14a1a59..f676bab 100644 --- a/config/env/dm_cartpole.yaml +++ b/config/env/dm_cartpole.yaml @@ -3,6 +3,7 @@ domain_name: cartpole task_name: swingup run_on_pixels: true obs_res: [64, 64] +repeat_action_num: 25 transforms: - _target_: rl_sandbox.utils.env.ActionNormalizer - _target_: rl_sandbox.utils.env.ActionDisritezer diff --git a/main.py b/main.py index cacaaaf..05a8634 100644 --- a/main.py +++ b/main.py @@ -20,15 +20,20 @@ def main(cfg: DictConfig): env: Env = hydra.utils.instantiate(cfg.env) + # TODO: add replay buffer implementation, which stores rollouts + # on disk buff = ReplayBuffer() fillup_replay_buffer(env, buff, cfg.training.batch_size) metrics_evaluator = MetricsEvaluator() + # TODO: Implement smarter techniques for exploration + # (Plan2Explore, etc) exploration_agent = RandomAgent(env) agent = hydra.utils.instantiate(cfg.agent, obs_space_num=env.observation_space.shape[0], - actions_num=env.action_space.shape[0], + # FIXME: feels bad + actions_num=(env.action_space.high - env.action_space.low + 1).item(), device_type=cfg.device_type) writer = SummaryWriter() @@ -60,14 +65,13 @@ def main(cfg: DictConfig): s, a, r, n, f = buff.sample(cfg.training.batch_size, cluster_size=cfg.agent.get('batch_cluster_size', 1)) - # TODO: add checkpoint saver for model losses = agent.train(s, a, 
r, n, f) for loss_name, loss in losses.items(): writer.add_scalar(f'train/{loss_name}', loss, global_step) global_step += 1 ### Validation - if epoch_num % 100 == 0: + if epoch_num % cfg.training.val_logs_every == 0: rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) metrics = metrics_evaluator.calculate_metrics(rollouts) for metric_name, metric in metrics.items(): @@ -79,6 +83,12 @@ def main(cfg: DictConfig): for rollout in rollouts: video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) writer.add_video('val/visualization', video, epoch_num) + # FIXME:Very bad from architecture point + agent.viz_log(rollout, writer, epoch_num) + + ### Checkpoint + # if epoch_num % cfg.training.save_checkpoint_every == 0: + # agent.save_ckpt(epoch_num, losses) if __name__ == "__main__": diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 764f8de..6ffab32 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -147,7 +147,14 @@ def forward(self, h_prev: t.Optional[tuple[torch.Tensor, torch.Tensor]], embed, # Use zero vector for prev_state of first if h_prev is None: - h_prev = (torch.zeros((*action.shape[:-1], self.hidden_size)), torch.zeros((*action.shape[:-1], self.latent_dim * self.latent_classes))) + h_prev = (torch.zeros(( + *action.shape[:-1], + self.hidden_size, + ), + device=next(self.stoch_net.parameters()).device), + torch.zeros( + (*action.shape[:-1], self.latent_dim * self.latent_classes), + device=next(self.stoch_net.parameters()).device)) deter_prev, stoch_prev = h_prev determ, prior_stoch_dist = self.predict_next(stoch_prev, action, @@ -220,6 +227,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.latent_dim = latent_dim self.latent_classes = latent_classes self.cluster_size = batch_cluster_size + self.actions_num = actions_num # kl loss balancing (prior/posterior) self.alpha = 0.8 @@ -261,7 +269,8 @@ def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor def get_latent(self, obs: torch.Tensor, action, state): embed = self.encoder(obs) - determ, _, latent_repr_dist = self.recurrent_model(state, embed.unsqueeze(0), action) + determ, _, latent_repr_dist = self.recurrent_model(state, embed.unsqueeze(0), + action) latent_repr = latent_repr_dist.rsample().reshape(-1, 32 * 32) return determ, latent_repr.unsqueeze(0) @@ -273,7 +282,7 @@ def train(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, embed = embed.view(b // self.cluster_size, self.cluster_size, -1) obs = obs.view(-1, self.cluster_size, 3, h, w) - a = a.view(-1, self.cluster_size, 1) + a = a.view(-1, self.cluster_size, self.actions_num) r = r.view(-1, self.cluster_size, 1) f = is_finished.view(-1, self.cluster_size, 1) @@ -315,7 +324,8 @@ def KL(dist1, dist2): h_prev = [determ_t, posterior_stoch.unsqueeze(0)] latent_vars.append(posterior_stoch.detach()) - loss = torch.Tensor(1) + # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device + loss = torch.Tensor(1).to(next(self.encoder.parameters()).device) for l in losses.values(): loss += l @@ -325,7 +335,8 @@ def KL(dist1, dist2): discovered_latents = torch.stack(latent_vars).reshape( -1, self.latent_dim * self.latent_classes) - return {l: val.detach() for l, val in losses.items()}, discovered_latents + return {l: val.detach().cpu().item() + for l, val in losses.items()}, discovered_latents class ImaginativeCritic(nn.Module): @@ -393,7 +404,7 @@ def __init__( device_type: str = 'cpu'): 
self._state = None - self._last_action = torch.zeros(actions_num) + self._last_action = torch.zeros(actions_num, device=device_type) self.actions_num = actions_num self.imagination_horizon = imagination_horizon self.cluster_size = batch_cluster_size @@ -411,12 +422,12 @@ def __init__( 400, 4, intermediate_activation=nn.ELU, - final_activation=Quantize) - # TODO: Leave only ImaginativeCritic and move Actor to DreamerV2 + final_activation=Quantize).to(device_type) self.critic = ImaginativeCritic(discount_factor, critic_update_interval, critic_soft_update_fraction, critic_value_target_lambda, - latent_dim * latent_classes, actions_num) + latent_dim * latent_classes, + actions_num).to(device_type) self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=4e-5) self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-4) @@ -439,7 +450,8 @@ def imagine_trajectory( def reset(self): self._state = None - self._last_action = torch.zeros((1, 1, self.actions_num)) + self._last_action = torch.zeros((1, 1, self.actions_num), + device=next(self.world_model.parameters()).device) def preprocess_obs(self, obs: torch.Tensor): order = list(range(len(obs.shape))) @@ -457,6 +469,41 @@ def get_action(self, obs: Observation) -> Action: return np.array([self._last_action.squeeze().detach().cpu().numpy().argmax()]) + def _generate_video(self, obs: Observation, init_action: Action): + obs = torch.from_numpy(obs.copy()).to(next(self.world_model.parameters()).device) + obs = self.preprocess_obs(obs).unsqueeze(0) + + action = F.one_hot(self.from_np(init_action).to(torch.int64), + num_classes=self.actions_num).squeeze() + z_0 = self.world_model.get_latent(obs, action.unsqueeze(0).unsqueeze(0), None)[1] + imagined_rollout = self.imagine_trajectory(z_0.squeeze()) + zs, _, _, _, _ = zip(*imagined_rollout) + reconstructed_plan = [] + for z in zs: + reconstructed_plan.append( + self.world_model.image_predictor(z).detach().numpy()) + video_r = np.concatenate(reconstructed_plan) + video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) + return video_r + + def viz_log(self, rollout, logger, epoch_num): + init_indeces = np.random.choice(len(rollout.states) - self.imagination_horizon, 3) + videos_r = np.concatenate([ + self._generate_video(obs_0.copy(), a_0) for obs_0, a_0 in zip( + rollout.states[init_indeces], rollout.actions[init_indeces]) + ], + axis=3) + + videos = np.concatenate([ + rollout.states[init_idx:init_idx + self.imagination_horizon].transpose( + 0, 3, 1, 2) for init_idx in init_indeces + ], + axis=3) + + logger.add_video('val/dreamed_rollout', + np.expand_dims(np.concatenate([videos, videos_r], axis=2), 0), + epoch_num) + def from_np(self, arr: np.ndarray): return torch.from_numpy(arr).to(next(self.world_model.parameters()).device) @@ -464,7 +511,8 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation is_finished: TerminationFlags): obs = self.preprocess_obs(self.from_np(obs)) - a = self.from_np(a) + a = self.from_np(a).to(torch.int64) + # Works incorrect a = F.one_hot(a, num_classes=self.actions_num).squeeze() r = self.from_np(r) next_obs = self.preprocess_obs(self.from_np(next_obs)) @@ -490,8 +538,8 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation reduction='sum') losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control - losses_ac['loss_actor_dynamics_backprop'] += (-(1 - self.rho) * vs[-1]).mean() - losses_ac['loss_actor_entropy'] += -self.eta * torch.stack( + losses_ac['loss_actor_dynamics_backprop'] += -((1 - 
self.rho) * vs).mean() + losses_ac['loss_actor_entropy'] += self.eta * torch.stack( [a.entropy() for a in action_dists[:-1]]).mean() losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] @@ -504,6 +552,20 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation self.critic_optimizer.step() self.critic.update_target() - losses_ac = {l: val.detach() for l, val in losses_ac.items()} + losses_ac = {l: val.detach().cpu().item() for l, val in losses_ac.items()} return losses | losses_ac + + def save_ckpt(self, epoch_num: int, losses: dict[str, float]): + torch.save( + { + 'epoch': epoch_num, + 'world_model_state_dict': self.world_model.state_dict(), + 'world_model_optimizer_state_dict': + self.world_model.optimizer.state_dict(), + 'actor_state_dict': self.actor.state_dict(), + 'critic_state_dict': self.critic.state_dict(), + 'actor_optimizer_state_dict': self.actor_optimizer.state_dict(), + 'critic_optimizer_state_dict': self.critic_optimizer.state_dict(), + 'losses': losses + }, f'dreamerV2-{epoch_num}-{sum(losses.values())}.ckpt') diff --git a/rl_sandbox/agents/random_agent.py b/rl_sandbox/agents/random_agent.py index b1dffb4..eab3d70 100644 --- a/rl_sandbox/agents/random_agent.py +++ b/rl_sandbox/agents/random_agent.py @@ -16,3 +16,6 @@ def get_action(self, obs: State) -> Action | NDArray[Shape["*"],Float]: def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): pass + + def save_ckpt(self, epoch_num: int, losses: dict[str, float]): + pass diff --git a/rl_sandbox/agents/rl_agent.py b/rl_sandbox/agents/rl_agent.py index b159db3..3f0ba09 100644 --- a/rl_sandbox/agents/rl_agent.py +++ b/rl_sandbox/agents/rl_agent.py @@ -19,3 +19,7 @@ def train(self, s: States, a: Actions, r: Rewards, next: States) -> dict[str, An # properly reseted between rollouts def reset(self): pass + + @abstractmethod + def save_ckpt(self, epoch_num: int, losses: dict[str, float]): + pass diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py index 200ecf9..bd5fbe8 100644 --- a/rl_sandbox/utils/env.py +++ b/rl_sandbox/utils/env.py @@ -81,24 +81,29 @@ def transform_action(self, action: NDArray[Shape['*'], return np.array(a) def transform_space(self, space: gym.spaces.Box): - return gym.spaces.Box(0, self.per_dim**len(self.low), dtype=np.uint32) + return gym.spaces.Box(0, self.per_dim**len(self.low)-1, dtype=np.int32) class Env(metaclass=ABCMeta): def __init__(self, run_on_pixels: bool, obs_res: tuple[int, int], - transforms: list[ActionTransformer]): + repeat_action_num: int, transforms: list[ActionTransformer]): self.obs_res = obs_res self.run_on_pixels = run_on_pixels + self.repeat_action_num = repeat_action_num + assert self.repeat_action_num >= 1 self.ac_trans = [] for t in transforms: t.set_env(self) self.ac_trans.append(t) def step(self, action: Action) -> EnvStepResult: + action = action.copy() for t in reversed(self.ac_trans): action = t.transform_action(action) - return self._step(action) + for _ in range(self.repeat_action_num): + res = self._step(action) + return res @abstractmethod def _step(self, action: Action) -> EnvStepResult: @@ -131,8 +136,8 @@ def action_space(self) -> gym.Space: class GymEnv(Env): def __init__(self, task_name: str, run_on_pixels: bool, obs_res: tuple[int, int], - transforms: list[ActionTransformer]): - super().__init__(run_on_pixels, obs_res, transforms) + repeat_action_num: int, transforms: list[ActionTransformer]): + 
super().__init__(run_on_pixels, obs_res, repeat_action_num, transforms) self.env: gym.Env = gym.make(task_name) self.visualized_env: gym.Env = gym.make(task_name, render_mode='rgb_array_list') @@ -158,11 +163,11 @@ def _action_space(self): class DmEnv(Env): - def __init__(self, run_on_pixels: bool, obs_res: tuple[int, int], domain_name: str, - task_name: str, - transforms: list[ActionTransformer]): + def __init__(self, run_on_pixels: bool, obs_res: tuple[int, + int], repeat_action_num: int, + domain_name: str, task_name: str, transforms: list[ActionTransformer]): self.env: dmEnviron = suite.load(domain_name=domain_name, task_name=task_name) - super().__init__(run_on_pixels, obs_res, transforms) + super().__init__(run_on_pixels, obs_res, repeat_action_num, transforms) def render(self): return self.env.physics.render(*self.obs_res, camera_id=0) diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index 109853a..ea17909 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -26,6 +26,7 @@ class Rollout: observations: t.Optional[Observations] = None +# TODO: make buffer concurrent-friendly class ReplayBuffer: def __init__(self, max_len=2_000): @@ -64,12 +65,12 @@ def add_sample(self, s: State, a: Action, r: float, n: State, f: bool): self.curr_rollout.is_finished.append(f) if f: - self.curr_rollout = None self.add_rollout( Rollout(np.array(self.curr_rollout.states), np.array(self.curr_rollout.actions), np.array(self.curr_rollout.rewards, dtype=np.float32), np.array([n]), np.array(self.curr_rollout.is_finished))) + self.curr_rollout = None def can_sample(self, num: int): return self.total_num >= num diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index f6742fa..1b74c65 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -32,9 +32,10 @@ def collect_rollout(env: Env, n.append(new_state) f.append(terminated) - # FIXME: obs are not collected yet + # FIXME: will break for non-DM + if collect_obs: + o.append(env.render()) # if collect_obs and isinstance(env, dmEnv): - # o.append(env.render()) state = new_state obs = None @@ -42,7 +43,7 @@ def collect_rollout(env: Env, # case gym.Env(): # obs = np.stack(list(env.render())) if obs_res is not None else None # case dmEnv(): - # obs = np.array(o) if obs_res is not None else None + obs = np.array(o) if collect_obs is not None else None return Rollout(np.array(s), np.array(a).reshape(len(s), -1), np.array(r, dtype=np.float32), np.array(n), np.array(f), obs) def collect_rollout_num(env: Env, @@ -52,7 +53,7 @@ def collect_rollout_num(env: Env, # TODO: paralelyze rollouts = [] for _ in range(num): - rollouts.append(collect_rollout(env, agent, obs_res)) + rollouts.append(collect_rollout(env, agent, collect_obs)) return rollouts From 24923e999b3ad43e8dbb91240dad5b82ca4a1777 Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 15 Nov 2022 21:27:05 +0000 Subject: [PATCH 014/106] Added dist probs & planning visualization --- config/agent/dreamer_v2.yaml | 16 ++-- config/config.yaml | 2 +- rl_sandbox/agents/dreamer_v2.py | 110 ++++++++++++++++--------- rl_sandbox/utils/rollout_generation.py | 5 +- 4 files changed, 85 insertions(+), 48 deletions(-) diff --git a/config/agent/dreamer_v2.yaml b/config/agent/dreamer_v2.yaml index 86abe78..806e19f 100644 --- a/config/agent/dreamer_v2.yaml +++ b/config/agent/dreamer_v2.yaml @@ -1,20 +1,26 @@ _target_: rl_sandbox.agents.DreamerV2 # World model parameters 
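For orientation, the world-model parameters below encode the discrete latent used throughout the agent code: with latent_dim: 32 and latent_classes: 32, the stochastic state is a stack of 32 categorical variables with 32 classes each, sampled with straight-through gradients and flattened to 32 * 32 = 1024 features before being fed to the actor, critic and decoder. That is what the hard-coded reshape(-1, 32 * 32) calls scattered through these patches refer to. A small shape-only sketch (variable names are illustrative, not part of the patch):

import torch
from torch.distributions import OneHotCategoricalStraightThrough

latent_dim, latent_classes = 32, 32
logits = torch.randn(1, latent_dim, latent_classes)                 # output of the RSSM head
stoch = OneHotCategoricalStraightThrough(logits=logits).rsample()   # (1, 32, 32), one-hot per row
feat = stoch.reshape(-1, latent_dim * latent_classes)               # (1, 1024) feature vector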
-batch_cluster_size: 16 +batch_cluster_size: 32 latent_dim: 32 latent_classes: 32 rssm_dim: 200 kl_loss_scale: 0.1 +kl_loss_balancing: 0.8 +world_model_lr: 3e-4 # ActorCritic parameters discount_factor: 0.995 imagination_horizon: 15 + +actor_lr: 8e-5 +# mixing of reinforce and maximizing value func +# for dm_control it is zero in Dreamer (Atari 1) +actor_reinforce_fraction: 0 +actor_entropy_scale: 1e-4 + +critic_lr: 8e-5 # Lambda parameter for trainin deeper multi-step prediction critic_value_target_lambda: 0.95 critic_update_interval: 100 # [0-1], 1 means hard update critic_soft_update_fraction: 1 -# mixing of reinforce and maximizing value func -# for dm_control it is zero in Dreamer (Atari 1) -actor_reinforce_fraction: 0 -actor_entropy_scale: 1e-4 diff --git a/config/config.yaml b/config/config.yaml index 0f5a117..31275b2 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -8,7 +8,7 @@ device_type: cpu training: epochs: 5000 - batch_size: 64 + batch_size: 128 gradient_steps_per_step: 4 save_checkpoint_every: 1000 val_logs_every: 5 diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 6ffab32..d01dcde 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -2,6 +2,7 @@ import typing as t from collections import defaultdict +import matplotlib.pyplot as plt import numpy as np import torch from torch import nn @@ -220,7 +221,7 @@ def forward(self, X): class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, - actions_num, kl_loss_scale): + actions_num, kl_loss_scale, kl_loss_balancing): super().__init__() self.kl_beta = kl_loss_scale self.rssm_dim = rssm_dim @@ -229,7 +230,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.cluster_size = batch_cluster_size self.actions_num = actions_num # kl loss balancing (prior/posterior) - self.alpha = 0.8 + self.alpha = kl_loss_balancing self.recurrent_model = RSSM(latent_dim, rssm_dim, @@ -249,12 +250,6 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, intermediate_activation=nn.ELU, final_activation=nn.Sigmoid) - self.optimizer = torch.optim.Adam(itertools.chain( - self.recurrent_model.parameters(), self.encoder.parameters(), - self.image_predictor.parameters(), self.reward_predictor.parameters(), - self.discount_predictor.parameters()), - lr=2e-4) - def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor]): determ_state, next_repr_dist = self.recurrent_model.predict_next( latent_repr.unsqueeze(0), action.unsqueeze(0), world_state) @@ -271,11 +266,10 @@ def get_latent(self, obs: torch.Tensor, action, state): embed = self.encoder(obs) determ, _, latent_repr_dist = self.recurrent_model(state, embed.unsqueeze(0), action) - latent_repr = latent_repr_dist.rsample().reshape(-1, 32 * 32) - return determ, latent_repr.unsqueeze(0) + return determ, latent_repr_dist - def train(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, - is_finished: torch.Tensor): + def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + is_finished: torch.Tensor): b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) @@ -324,19 +318,9 @@ def KL(dist1, dist2): h_prev = [determ_t, posterior_stoch.unsqueeze(0)] latent_vars.append(posterior_stoch.detach()) - # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device - loss = torch.Tensor(1).to(next(self.encoder.parameters()).device) - for l in 
losses.values(): - loss += l - - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - discovered_latents = torch.stack(latent_vars).reshape( -1, self.latent_dim * self.latent_classes) - return {l: val.detach().cpu().item() - for l, val in losses.items()}, discovered_latents + return losses, discovered_latents class ImaginativeCritic(nn.Module): @@ -395,16 +379,18 @@ def __init__( rssm_dim: int, discount_factor: float, kl_loss_scale: float, + kl_loss_balancing: float, imagination_horizon: int, critic_update_interval: int, actor_reinforce_fraction: float, actor_entropy_scale: float, critic_soft_update_fraction: float, critic_value_target_lambda: float, + world_model_lr: float, + actor_lr: float, + critic_lr: float, device_type: str = 'cpu'): - self._state = None - self._last_action = torch.zeros(actions_num, device=device_type) self.actions_num = actions_num self.imagination_horizon = imagination_horizon self.cluster_size = batch_cluster_size @@ -415,8 +401,8 @@ def __init__( self.eta = actor_entropy_scale self.world_model = WorldModel(batch_cluster_size, latent_dim, latent_classes, - rssm_dim, actions_num, - kl_loss_scale).to(device_type) + rssm_dim, actions_num, kl_loss_scale, + kl_loss_balancing).to(device_type) self.actor = fc_nn_generator(latent_dim * latent_classes, actions_num, 400, @@ -429,8 +415,19 @@ def __init__( latent_dim * latent_classes, actions_num).to(device_type) - self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=4e-5) - self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-4) + self.world_model_optimizer = torch.optim.AdamW(self.world_model.parameters(), + lr=world_model_lr, + eps=1e-5, + weight_decay=1e-6) + self.actor_optimizer = torch.optim.AdamW(self.actor.parameters(), + lr=actor_lr, + eps=1e-5, + weight_decay=1e-6) + self.critic_optimizer = torch.optim.AdamW(self.critic.parameters(), + lr=critic_lr, + eps=1e-5, + weight_decay=1e-6) + self.reset() def imagine_trajectory( self, z_0 @@ -452,6 +449,9 @@ def reset(self): self._state = None self._last_action = torch.zeros((1, 1, self.actions_num), device=next(self.world_model.parameters()).device) + self._latent_probs = torch.zeros((32, 32)) + self._action_probs = torch.zeros((self.actions_num)) + self._stored_steps = 0 def preprocess_obs(self, obs: torch.Tensor): order = list(range(len(obs.shape))) @@ -464,8 +464,17 @@ def get_action(self, obs: Observation) -> Action: obs = torch.from_numpy(obs.copy()).to(next(self.world_model.parameters()).device) obs = self.preprocess_obs(obs).unsqueeze(0) - self._state = self.world_model.get_latent(obs, self._last_action, self._state) - self._last_action = self.actor(self._state[1]).rsample() + determ, latent_repr_dist = self.world_model.get_latent(obs, self._last_action, + self._state) + self._state = (determ, latent_repr_dist.rsample().reshape(-1, + 32 * 32).unsqueeze(0)) + + actor_dist = self.actor(self._state[1]) + self._last_action = actor_dist.rsample() + + self._action_probs += actor_dist.probs.squeeze() + self._latent_probs += latent_repr_dist.probs.squeeze() + self._stored_steps += 1 return np.array([self._last_action.squeeze().detach().cpu().numpy().argmax()]) @@ -475,7 +484,10 @@ def _generate_video(self, obs: Observation, init_action: Action): action = F.one_hot(self.from_np(init_action).to(torch.int64), num_classes=self.actions_num).squeeze() - z_0 = self.world_model.get_latent(obs, action.unsqueeze(0).unsqueeze(0), None)[1] + z_0 = self.world_model.get_latent(obs, + action.unsqueeze(0).unsqueeze(0), + 
None)[1].rsample().reshape(-1, + 32 * 32).unsqueeze(0) imagined_rollout = self.imagine_trajectory(z_0.squeeze()) zs, _, _, _, _ = zip(*imagined_rollout) reconstructed_plan = [] @@ -499,10 +511,17 @@ def viz_log(self, rollout, logger, epoch_num): 0, 3, 1, 2) for init_idx in init_indeces ], axis=3) - - logger.add_video('val/dreamed_rollout', - np.expand_dims(np.concatenate([videos, videos_r], axis=2), 0), - epoch_num) + videos_comparison = np.expand_dims(np.concatenate([videos, videos_r], axis=2), 0) + latent_hist = (self._latent_probs/self._stored_steps*255.0).detach().cpu().numpy().reshape(-1, 32, 32) + action_hist = (self._action_probs/self._stored_steps).detach().cpu().numpy() + + # logger.add_histogram('val/action_probs', action_hist, epoch_num) + fig = plt.Figure() + ax = fig.add_axes([0,0,1,1]) + ax.bar(np.arange(self.actions_num), action_hist) + logger.add_figure('val/action_probs', fig, epoch_num) + logger.add_image('val/latent_probs', latent_hist, epoch_num) + logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) def from_np(self, arr: np.ndarray): return torch.from_numpy(arr).to(next(self.world_model.parameters()).device) @@ -512,14 +531,23 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation obs = self.preprocess_obs(self.from_np(obs)) a = self.from_np(a).to(torch.int64) - # Works incorrect a = F.one_hot(a, num_classes=self.actions_num).squeeze() r = self.from_np(r) next_obs = self.preprocess_obs(self.from_np(next_obs)) is_finished = self.from_np(is_finished) # take some latent embeddings as initial step - losses, discovered_latents = self.world_model.train(next_obs, a, r, is_finished) + losses, discovered_latents = self.world_model.calculate_loss( + next_obs, a, r, is_finished) + + # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device + world_model_loss = torch.Tensor(1).to(next(self.world_model.parameters()).device) + for l in losses.values(): + world_model_loss += l + + self.world_model_optimizer.zero_grad() + world_model_loss.backward() + self.world_model_optimizer.step() idx = torch.randperm(discovered_latents.size(0)) initial_states = discovered_latents[idx] @@ -538,7 +566,8 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation reduction='sum') losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control - losses_ac['loss_actor_dynamics_backprop'] += -((1 - self.rho) * vs).mean() + losses_ac['loss_actor_dynamics_backprop'] += -( + (1 - self.rho) * vs[:-1]).mean() losses_ac['loss_actor_entropy'] += self.eta * torch.stack( [a.entropy() for a in action_dists[:-1]]).mean() losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ @@ -552,6 +581,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation self.critic_optimizer.step() self.critic.update_target() + losses = {l: val.detach().cpu().item() for l, val in losses.items()} losses_ac = {l: val.detach().cpu().item() for l, val in losses_ac.items()} return losses | losses_ac @@ -562,7 +592,7 @@ def save_ckpt(self, epoch_num: int, losses: dict[str, float]): 'epoch': epoch_num, 'world_model_state_dict': self.world_model.state_dict(), 'world_model_optimizer_state_dict': - self.world_model.optimizer.state_dict(), + self.world_model_optimizer.state_dict(), 'actor_state_dict': self.actor.state_dict(), 'critic_state_dict': self.critic.state_dict(), 'actor_optimizer_state_dict': self.actor_optimizer.state_dict(), diff --git a/rl_sandbox/utils/rollout_generation.py 
b/rl_sandbox/utils/rollout_generation.py index 1b74c65..cde5f21 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -16,11 +16,12 @@ def collect_rollout(env: Env, s, a, r, n, f, o = [], [], [], [], [], [] - state, _, terminated = unpack(env.reset()) - if agent is None: agent = RandomAgent(env) + state, _, terminated = unpack(env.reset()) + agent.reset() + while not terminated: action = agent.get_action(state) From b5ed8b170d2a7d9efe8e58b24facdc81c09aced5 Mon Sep 17 00:00:00 2001 From: Midren Date: Sun, 4 Dec 2022 22:56:10 +0000 Subject: [PATCH 015/106] Performance improvement --- main.py | 5 ++- pyproject.toml | 2 + rl_sandbox/agents/dreamer_v2.py | 75 ++++++++++++++++++--------------- 3 files changed, 45 insertions(+), 37 deletions(-) diff --git a/main.py b/main.py index 05a8634..4a07b27 100644 --- a/main.py +++ b/main.py @@ -87,9 +87,10 @@ def main(cfg: DictConfig): agent.viz_log(rollout, writer, epoch_num) ### Checkpoint - # if epoch_num % cfg.training.save_checkpoint_every == 0: - # agent.save_ckpt(epoch_num, losses) + if epoch_num % cfg.training.save_checkpoint_every == 0: + agent.save_ckpt(epoch_num, losses) if __name__ == "__main__": main() + diff --git a/pyproject.toml b/pyproject.toml index dbd22d4..3e22666 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,3 +25,5 @@ torch = '^1.12' tensorboard = '^2.0' dm-control = '^1.0.0' unpackable = '^0.0.4' +hydra-core = "^1.2.0" +matplotlib = "^3.0.0" diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index d01dcde..0f4f17c 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -336,12 +336,12 @@ def __init__(self, discount_factor: float, update_interval: int, self._update_num = 0 self.critic = fc_nn_generator(latent_dim, - actions_num, + 1, 400, 1, intermediate_activation=nn.ELU) self.target_critic = fc_nn_generator(latent_dim, - actions_num, + 1, 400, 1, intermediate_activation=nn.ELU) @@ -360,11 +360,13 @@ def estimate_value(self, z) -> torch.Tensor: def lambda_return(self, zs, rs, ts): v_lambdas = [self.target_critic(zs[-1])] - for r, z, t in zip(reversed(rs[:-1]), reversed(zs[:-1]), reversed(ts[:-1])): - v_lambda = r + t * self.gamma * ( - (1 - self.lambda_) * self.target_critic(z) + self.lambda_ * v_lambdas[-1]) + for i in range(zs.shape[0] - 2, -1, -1): + v_lambda = rs[i] + ts[i] * self.gamma * ( + (1 - self.lambda_) * self.target_critic(zs[i]).detach() + + self.lambda_ * v_lambdas[-1]) v_lambdas.append(v_lambda) - return torch.concat(list(reversed(v_lambdas)), dim=0) + + return torch.stack(list(reversed(v_lambdas))) class DreamerV2(RlAgent): @@ -431,19 +433,25 @@ def __init__( def imagine_trajectory( self, z_0 - ) -> list[tuple[torch.Tensor, torch.distributions.Distribution, torch.Tensor, - torch.Tensor]]: - rollout = [] + ) -> tuple[torch.Tensor, torch.distributions.Distribution, torch.Tensor, torch.Tensor, + torch.Tensor]: world_state = None - z = z_0.detach().unsqueeze(0) + zs, actions, next_zs, rewards, ts = [], [], [], [], [] + z = z_0.detach() for _ in range(self.imagination_horizon): a = self.actor(z) world_state, next_z, reward, is_finished = self.world_model.predict_next( z, a.rsample(), world_state) - rollout.append( - (z.detach(), a, next_z.detach(), reward.detach(), is_finished.detach())) + + zs.append(z) + actions.append(a) + next_zs.append(next_z) + rewards.append(reward) + ts.append(is_finished) + z = next_z.detach() - return rollout + return (torch.stack(zs).detach(), actions, 
torch.stack(next_zs).detach(), torch.stack(rewards).detach(), + torch.stack(ts).detach()) def reset(self): self._state = None @@ -488,8 +496,7 @@ def _generate_video(self, obs: Observation, init_action: Action): action.unsqueeze(0).unsqueeze(0), None)[1].rsample().reshape(-1, 32 * 32).unsqueeze(0) - imagined_rollout = self.imagine_trajectory(z_0.squeeze()) - zs, _, _, _, _ = zip(*imagined_rollout) + zs, _, _, _, _ = self.imagine_trajectory(z_0) reconstructed_plan = [] for z in zs: reconstructed_plan.append( @@ -512,12 +519,13 @@ def viz_log(self, rollout, logger, epoch_num): ], axis=3) videos_comparison = np.expand_dims(np.concatenate([videos, videos_r], axis=2), 0) - latent_hist = (self._latent_probs/self._stored_steps*255.0).detach().cpu().numpy().reshape(-1, 32, 32) - action_hist = (self._action_probs/self._stored_steps).detach().cpu().numpy() + latent_hist = (self._latent_probs / self._stored_steps * + 255.0).detach().cpu().numpy().reshape(-1, 32, 32) + action_hist = (self._action_probs / self._stored_steps).detach().cpu().numpy() # logger.add_histogram('val/action_probs', action_hist, epoch_num) fig = plt.Figure() - ax = fig.add_axes([0,0,1,1]) + ax = fig.add_axes([0, 0, 1, 1]) ax.bar(np.arange(self.actions_num), action_hist) logger.add_figure('val/action_probs', fig, epoch_num) logger.add_image('val/latent_probs', latent_hist, epoch_num) @@ -555,23 +563,20 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation losses_ac = defaultdict( lambda: torch.zeros(1).to(next(self.critic.parameters()).device)) - for z_0 in initial_states: - rollout = self.imagine_trajectory(z_0) - zs, action_dists, next_zs, rewards, terminal_flags = zip(*rollout) - vs = self.critic.lambda_return(next_zs, rewards, terminal_flags) - - losses_ac['loss_critic'] += F.mse_loss(self.critic.estimate_value( - torch.stack(next_zs).squeeze(1)), - vs.detach(), - reduction='sum') - - losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control - losses_ac['loss_actor_dynamics_backprop'] += -( - (1 - self.rho) * vs[:-1]).mean() - losses_ac['loss_actor_entropy'] += self.eta * torch.stack( - [a.entropy() for a in action_dists[:-1]]).mean() - losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ - 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] + zs, action_dists, next_zs, rewards, terminal_flags = self.imagine_trajectory( + initial_states) + vs = self.critic.lambda_return(next_zs, rewards, terminal_flags) + + losses_ac['loss_critic'] = F.mse_loss(self.critic.estimate_value(next_zs), + vs.detach(), + reduction='sum') / zs.shape[1] + losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control + losses_ac['loss_actor_dynamics_backprop'] = -( + (1 - self.rho) * vs).sum() / zs.shape[1] + losses_ac['loss_actor_entropy'] += self.eta * torch.stack( + [a.entropy() for a in action_dists[:-1]]).sum() / zs.shape[1] + losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ + 'loss_actor_dynamics_backprop'] #+ losses_ac['loss_actor_entropy'] self.actor_optimizer.zero_grad() self.critic_optimizer.zero_grad() From 769e3082f858b0e03fbcf054703566f2f216fb2d Mon Sep 17 00:00:00 2001 From: Midren Date: Mon, 5 Dec 2022 13:19:57 +0000 Subject: [PATCH 016/106] Fixes --- rl_sandbox/agents/dreamer_v2.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 0f4f17c..03f7f12 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -42,9 
+42,8 @@ def forward(self, x): class Quantize(nn.Module): def forward(self, logits): - dist = torch.distributions.one_hot_categorical.OneHotCategoricalStraightThrough( + return torch.distributions.one_hot_categorical.OneHotCategoricalStraightThrough( logits=logits) - return dist class RSSM(nn.Module): @@ -450,8 +449,8 @@ def imagine_trajectory( ts.append(is_finished) z = next_z.detach() - return (torch.stack(zs).detach(), actions, torch.stack(next_zs).detach(), torch.stack(rewards).detach(), - torch.stack(ts).detach()) + return (torch.stack(zs), actions, torch.stack(next_zs), + torch.stack(rewards).detach(), torch.stack(ts).detach()) def reset(self): self._state = None @@ -497,10 +496,9 @@ def _generate_video(self, obs: Observation, init_action: Action): None)[1].rsample().reshape(-1, 32 * 32).unsqueeze(0) zs, _, _, _, _ = self.imagine_trajectory(z_0) - reconstructed_plan = [] - for z in zs: - reconstructed_plan.append( - self.world_model.image_predictor(z).detach().numpy()) + reconstructed_plan = [ + self.world_model.image_predictor(z).detach().numpy() for z in zs + ] video_r = np.concatenate(reconstructed_plan) video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) return video_r @@ -573,10 +571,10 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control losses_ac['loss_actor_dynamics_backprop'] = -( (1 - self.rho) * vs).sum() / zs.shape[1] - losses_ac['loss_actor_entropy'] += self.eta * torch.stack( + losses_ac['loss_actor_entropy'] = -self.eta * torch.stack( [a.entropy() for a in action_dists[:-1]]).sum() / zs.shape[1] losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ - 'loss_actor_dynamics_backprop'] #+ losses_ac['loss_actor_entropy'] + 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] self.actor_optimizer.zero_grad() self.critic_optimizer.zero_grad() From 5bbc60c3b35c5457cabd55ea358d4cf8595e3024 Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 6 Dec 2022 00:57:15 +0000 Subject: [PATCH 017/106] Fixes --- config/config.yaml | 2 +- rl_sandbox/agents/dreamer_v2.py | 27 ++++++++++++++++----------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 31275b2..b6e9c26 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -8,7 +8,7 @@ device_type: cpu training: epochs: 5000 - batch_size: 128 + batch_size: 1024 gradient_steps_per_step: 4 save_checkpoint_every: 1000 val_logs_every: 5 diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 03f7f12..137bc43 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -81,16 +81,19 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): nn.Linear(latent_dim * latent_classes + actions_num, hidden_size), # Dreamer 'img_in' nn.LayerNorm(hidden_size), + nn.ELU(inplace=True) ) self.determ_recurrent = nn.GRU(input_size=hidden_size, hidden_size=hidden_size) # Dreamer gru '_cell' # Calculate stochastic state from prior embed # shared between all ensemble models + # FIXME: check whether it is trully correct self.ensemble_prior_estimator = nn.ModuleList([ nn.Sequential( nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' nn.LayerNorm(hidden_size), + nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' View((-1, latent_dim, self.latent_classes)), @@ -101,10 +104,9 @@ def __init__(self, latent_dim, hidden_size, actions_num, 
latent_classes): # FIXME: very bad magic number img_sz = 4 * 384 # 384*2x2 self.stoch_net = nn.Sequential( - nn.Linear(hidden_size + img_sz, hidden_size), - nn.LayerNorm(hidden_size), - nn.Linear(hidden_size, hidden_size), # Dreamer 'obs_out' + nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' nn.LayerNorm(hidden_size), + nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'obs_dist' View((-1, latent_dim, self.latent_classes)), @@ -116,18 +118,20 @@ def estimate_stochastic_latent(self, prev_determ): dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] # NOTE: Maybe something smarter can be used instead of # taking only one random between all ensembles + # FIXME: temporary use the same model idx = torch.randint(0, self.ensemble_num, ()) - return dists_per_model[idx] + return dists_per_model[0] def predict_next(self, stoch_latent, action, deter_state: t.Optional[torch.Tensor] = None): + # FIXME: Move outside of rssm to omit checking if deter_state is None: deter_state = torch.zeros(*stoch_latent.shape[:2], self.hidden_size).to( next(self.stoch_net.parameters()).device) x = self.pre_determ_recurrent(torch.concat([stoch_latent, action], dim=2)) - _, determ = self.determ_recurrent(x, deter_state) + _, determ = self.determ_recurrent(x, h_0=deter_state) # used for KL divergence predicted_stoch_latent = self.estimate_stochastic_latent(determ) @@ -145,7 +149,8 @@ def forward(self, h_prev: t.Optional[tuple[torch.Tensor, torch.Tensor]], embed, Returns 'h_next' <- the next next of the world """ - # Use zero vector for prev_state of first + # FIXME: Use zero vector for prev_state of first + # Move outside of rssm to omit checking if h_prev is None: h_prev = (torch.zeros(( *action.shape[:-1], @@ -189,12 +194,12 @@ def forward(self, X): class Decoder(nn.Module): - def __init__(self, kernel_sizes=[5, 5, 6, 6]): + def __init__(self, input_size, kernel_sizes=[5, 5, 6, 6]): super().__init__() layers = [] self.channel_step = 48 # 2**(len(kernel_sizes)-1)*channel_step - self.convin = nn.Linear(32 * 32, 32 * self.channel_step) + self.convin = nn.Linear(input_size, 32 * self.channel_step) in_channels = 32 * self.channel_step #2**(len(kernel_sizes) - 1) * self.channel_step for i, k in enumerate(kernel_sizes): @@ -236,7 +241,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, latent_classes=latent_classes) self.encoder = Encoder() - self.image_predictor = Decoder() + self.image_predictor = Decoder(rssm_dim + latent_dim * latent_classes) self.reward_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, hidden_size=400, @@ -302,11 +307,11 @@ def KL(dist1, dist2): -1, self.latent_dim * self.latent_classes) r_t_pred = self.reward_predictor( - torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) + torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1).detach()) f_t_pred = self.discount_predictor( torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) - x_r = self.image_predictor(posterior_stoch) + x_r = self.image_predictor([determ_t.squeeze(0), posterior_stoch]) losses['loss_reconstruction'] = nn.functional.mse_loss(x_t, x_r) losses['loss_reward_pred'] += F.mse_loss(r_t, r_t_pred) From 03a0ee74007d71f42fd1340fe8b165bd52a12463 Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 6 Dec 2022 13:07:38 +0000 Subject: [PATCH 018/106] Various fixes and refactorings --- config/env/dm_cartpole.yaml | 2 +- main.py | 57 ++++++++++++-------------- 
rl_sandbox/agents/dreamer_v2.py | 34 +++++++-------- rl_sandbox/agents/explorative_agent.py | 28 +++++++++++++ rl_sandbox/agents/random_agent.py | 2 +- rl_sandbox/agents/rl_agent.py | 4 +- rl_sandbox/utils/env.py | 25 +++++++---- rl_sandbox/utils/replay_buffer.py | 45 ++++++++++++-------- rl_sandbox/utils/rollout_generation.py | 36 +++++++++------- 9 files changed, 141 insertions(+), 92 deletions(-) create mode 100644 rl_sandbox/agents/explorative_agent.py diff --git a/config/env/dm_cartpole.yaml b/config/env/dm_cartpole.yaml index f676bab..2fcabf9 100644 --- a/config/env/dm_cartpole.yaml +++ b/config/env/dm_cartpole.yaml @@ -3,7 +3,7 @@ domain_name: cartpole task_name: swingup run_on_pixels: true obs_res: [64, 64] -repeat_action_num: 25 +repeat_action_num: 5 transforms: - _target_: rl_sandbox.utils.env.ActionNormalizer - _target_: rl_sandbox.utils.env.ActionDisritezer diff --git a/main.py b/main.py index 4a07b27..a5c132c 100644 --- a/main.py +++ b/main.py @@ -6,10 +6,11 @@ from unpackable import unpack from rl_sandbox.agents.random_agent import RandomAgent +from rl_sandbox.agents.explorative_agent import ExplorativeAgent from rl_sandbox.metrics import MetricsEvaluator from rl_sandbox.utils.env import Env from rl_sandbox.utils.replay_buffer import ReplayBuffer -from rl_sandbox.utils.rollout_generation import (collect_rollout_num, +from rl_sandbox.utils.rollout_generation import (collect_rollout, collect_rollout_num, iter_rollout, fillup_replay_buffer) from rl_sandbox.utils.schedulers import LinearScheduler @@ -29,46 +30,40 @@ def main(cfg: DictConfig): # TODO: Implement smarter techniques for exploration # (Plan2Explore, etc) - exploration_agent = RandomAgent(env) - agent = hydra.utils.instantiate(cfg.agent, + + policy_agent = hydra.utils.instantiate(cfg.agent, obs_space_num=env.observation_space.shape[0], # FIXME: feels bad actions_num=(env.action_space.high - env.action_space.low + 1).item(), device_type=cfg.device_type) - + agent = ExplorativeAgent( + policy_agent, + # TODO: For dreamer, add noise for sampling instead + # of just random actions + RandomAgent(env), + LinearScheduler(0.9, 0.01, 5_000)) writer = SummaryWriter() - scheduler = LinearScheduler(0.9, 0.01, 5_000) - global_step = 0 - for epoch_num in tqdm(range(cfg.training.epochs)): + pbar = tqdm(total=cfg.training.epochs*200) + for epoch_num in range(cfg.training.epochs): ### Training and exploration - state, _, _ = unpack(env.reset()) - agent.reset() + # TODO: add buffer end prioritarization + for s, a, r, n, f, _ in iter_rollout(env, agent): + buff.add_sample(s, a, r, n, f) - terminated = False - while not terminated: if global_step % cfg.training.gradient_steps_per_step == 0: - # TODO: For dreamer, add noise for sampling - if np.random.random() > scheduler.step(): - action = exploration_agent.get_action(state) - else: - action = agent.get_action(state) - - new_state, reward, terminated = unpack(env.step(action)) - - buff.add_sample(state, action, reward, new_state, terminated) - - # NOTE: unintuitive that batch_size is now number of total - # samples, but not amount of sequences for recurrent model - s, a, r, n, f = buff.sample(cfg.training.batch_size, - cluster_size=cfg.agent.get('batch_cluster_size', 1)) - - losses = agent.train(s, a, r, n, f) - for loss_name, loss in losses.items(): - writer.add_scalar(f'train/{loss_name}', loss, global_step) + # NOTE: unintuitive that batch_size is now number of total + # samples, but not amount of sequences for recurrent model + s, a, r, n, f = buff.sample(cfg.training.batch_size, + 
cluster_size=cfg.agent.get('batch_cluster_size', 1)) + + losses = agent.train(s, a, r, n, f) + for loss_name, loss in losses.items(): + writer.add_scalar(f'train/{loss_name}', loss, global_step) global_step += 1 + pbar.update(1) ### Validation if epoch_num % cfg.training.val_logs_every == 0: @@ -83,8 +78,8 @@ def main(cfg: DictConfig): for rollout in rollouts: video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) writer.add_video('val/visualization', video, epoch_num) - # FIXME:Very bad from architecture point - agent.viz_log(rollout, writer, epoch_num) + # FIXME: Very bad from architecture point + agent.policy_ag.viz_log(rollout, writer, epoch_num) ### Checkpoint if epoch_num % cfg.training.save_checkpoint_every == 0: diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 137bc43..8b5c983 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -131,7 +131,7 @@ def predict_next(self, deter_state = torch.zeros(*stoch_latent.shape[:2], self.hidden_size).to( next(self.stoch_net.parameters()).device) x = self.pre_determ_recurrent(torch.concat([stoch_latent, action], dim=2)) - _, determ = self.determ_recurrent(x, h_0=deter_state) + _, determ = self.determ_recurrent(x, deter_state) # used for KL divergence predicted_stoch_latent = self.estimate_stochastic_latent(determ) @@ -290,9 +290,10 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, def KL(dist1, dist2): KL_ = torch.distributions.kl_divergence Dist = torch.distributions.OneHotCategoricalStraightThrough - return self.kl_beta * ( - self.alpha * KL_(dist1, Dist(logits=dist2.logits.detach())).mean() + - (1 - self.alpha) * KL_(Dist(logits=dist1.logits.detach()), dist2).mean()) + one = torch.ones(1,device=next(self.parameters()).device) + kl_lhs = torch.maximum(KL_(Dist(logits=dist2.logits.detach()), dist1), one) + kl_rhs = torch.maximum(KL_(dist2, Dist(logits=dist1.logits.detach())), one) + return self.kl_beta * (self.alpha * kl_lhs.mean()+ (1 - self.alpha) * kl_rhs.mean()) latent_vars = [] @@ -311,7 +312,7 @@ def KL(dist1, dist2): f_t_pred = self.discount_predictor( torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) - x_r = self.image_predictor([determ_t.squeeze(0), posterior_stoch]) + x_r = self.image_predictor(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) losses['loss_reconstruction'] = nn.functional.mse_loss(x_t, x_r) losses['loss_reward_pred'] += F.mse_loss(r_t, r_t_pred) @@ -409,6 +410,8 @@ def __init__( self.world_model = WorldModel(batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing).to(device_type) + # TODO: final activation should depend whether agent + # action space in one hot or identity if real-valued self.actor = fc_nn_generator(latent_dim * latent_classes, actions_num, 400, @@ -440,7 +443,7 @@ def imagine_trajectory( ) -> tuple[torch.Tensor, torch.distributions.Distribution, torch.Tensor, torch.Tensor, torch.Tensor]: world_state = None - zs, actions, next_zs, rewards, ts = [], [], [], [], [] + zs, actions, next_zs, rewards, ts, determs = [], [], [], [], [], [] z = z_0.detach() for _ in range(self.imagination_horizon): a = self.actor(z) @@ -452,17 +455,18 @@ def imagine_trajectory( next_zs.append(next_z) rewards.append(reward) ts.append(is_finished) + determs.append(world_state[0]) z = next_z.detach() return (torch.stack(zs), actions, torch.stack(next_zs), - torch.stack(rewards).detach(), torch.stack(ts).detach()) + torch.stack(rewards).detach(), 
torch.stack(ts).detach(), torch.stack(determs)) def reset(self): self._state = None self._last_action = torch.zeros((1, 1, self.actions_num), device=next(self.world_model.parameters()).device) - self._latent_probs = torch.zeros((32, 32)) - self._action_probs = torch.zeros((self.actions_num)) + self._latent_probs = torch.zeros((32, 32), device=next(self.world_model.parameters()).device) + self._action_probs = torch.zeros((self.actions_num), device=next(self.world_model.parameters()).device) self._stored_steps = 0 def preprocess_obs(self, obs: torch.Tensor): @@ -500,11 +504,8 @@ def _generate_video(self, obs: Observation, init_action: Action): action.unsqueeze(0).unsqueeze(0), None)[1].rsample().reshape(-1, 32 * 32).unsqueeze(0) - zs, _, _, _, _ = self.imagine_trajectory(z_0) - reconstructed_plan = [ - self.world_model.image_predictor(z).detach().numpy() for z in zs - ] - video_r = np.concatenate(reconstructed_plan) + zs, _, _, _, _, determs = self.imagine_trajectory(z_0.squeeze(0)) + video_r = self.world_model.image_predictor(torch.concat([determs, zs], dim=2)).cpu().detach().numpy() video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) return video_r @@ -532,6 +533,7 @@ def viz_log(self, rollout, logger, epoch_num): ax.bar(np.arange(self.actions_num), action_hist) logger.add_figure('val/action_probs', fig, epoch_num) logger.add_image('val/latent_probs', latent_hist, epoch_num) + logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=2), epoch_num) logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) def from_np(self, arr: np.ndarray): @@ -566,11 +568,11 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation losses_ac = defaultdict( lambda: torch.zeros(1).to(next(self.critic.parameters()).device)) - zs, action_dists, next_zs, rewards, terminal_flags = self.imagine_trajectory( + zs, action_dists, next_zs, rewards, terminal_flags, _ = self.imagine_trajectory( initial_states) vs = self.critic.lambda_return(next_zs, rewards, terminal_flags) - losses_ac['loss_critic'] = F.mse_loss(self.critic.estimate_value(next_zs), + losses_ac['loss_critic'] = F.mse_loss(self.critic.estimate_value(next_zs.detach()), vs.detach(), reduction='sum') / zs.shape[1] losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control diff --git a/rl_sandbox/agents/explorative_agent.py b/rl_sandbox/agents/explorative_agent.py new file mode 100644 index 0000000..c26ebca --- /dev/null +++ b/rl_sandbox/agents/explorative_agent.py @@ -0,0 +1,28 @@ +import numpy as np +from nptyping import Float, NDArray, Shape + +from rl_sandbox.agents.rl_agent import RlAgent +from rl_sandbox.utils.schedulers import Scheduler +from rl_sandbox.utils.replay_buffer import (Action, Actions, Rewards, State, + States, TerminationFlags) + + +class ExplorativeAgent(RlAgent): + def __init__(self, policy_agent: RlAgent, + exploration_agent: RlAgent, + scheduler: Scheduler): + self.policy_ag = policy_agent + self.expl_ag = exploration_agent + self.scheduler = scheduler + + def get_action(self, obs: State) -> Action | NDArray[Shape["*"],Float]: + if np.random.random() > self.scheduler.step(): + return self.expl_ag.get_action(obs) + return self.policy_ag.get_action(obs) + + def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): + return self.expl_ag.train(s, a, r, next, is_finished) | self.policy_ag.train(s, a, r, next, is_finished) + + def save_ckpt(self, epoch_num: int, losses: dict[str, float]): + self.policy_ag.save_ckpt(epoch_num, losses) + 
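A caveat about the epsilon-style mixing in ExplorativeAgent.get_action above: assuming LinearScheduler(0.9, 0.01, 5_000) interpolates from 0.9 down to 0.01, the branch np.random.random() > self.scheduler.step() selects the exploration agent with probability 1 - value, which grows from about 0.1 to 0.99 over training instead of shrinking. If the usual annealed exploration is intended, the comparison (or the schedule endpoints) probably needs to be flipped; a conventional epsilon-greedy sketch for comparison (illustrative, not part of the patch):

import numpy as np

def choose_action(policy_agent, exploration_agent, obs, eps):
    # eps is expected to decay over training (e.g. 0.9 -> 0.01): explore with probability eps
    if np.random.random() < eps:
        return exploration_agent.get_action(obs)
    return policy_agent.get_action(obs)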
self.expl_ag.save_ckpt(epoch_num, losses) diff --git a/rl_sandbox/agents/random_agent.py b/rl_sandbox/agents/random_agent.py index eab3d70..1638a93 100644 --- a/rl_sandbox/agents/random_agent.py +++ b/rl_sandbox/agents/random_agent.py @@ -15,7 +15,7 @@ def get_action(self, obs: State) -> Action | NDArray[Shape["*"],Float]: return self.action_space.sample() def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): - pass + return dict() def save_ckpt(self, epoch_num: int, losses: dict[str, float]): pass diff --git a/rl_sandbox/agents/rl_agent.py b/rl_sandbox/agents/rl_agent.py index 3f0ba09..0080fb5 100644 --- a/rl_sandbox/agents/rl_agent.py +++ b/rl_sandbox/agents/rl_agent.py @@ -1,7 +1,7 @@ from typing import Any from abc import ABCMeta, abstractmethod -from rl_sandbox.utils.replay_buffer import Action, State, States, Actions, Rewards +from rl_sandbox.utils.replay_buffer import Action, State, States, Actions, Rewards, TerminationFlags class RlAgent(metaclass=ABCMeta): @abstractmethod @@ -9,7 +9,7 @@ def get_action(self, obs: State) -> Action: pass @abstractmethod - def train(self, s: States, a: Actions, r: Rewards, next: States) -> dict[str, Any]: + def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags) -> dict[str, Any]: """ Return dict with losses for logging """ diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py index bd5fbe8..2b8fe93 100644 --- a/rl_sandbox/utils/env.py +++ b/rl_sandbox/utils/env.py @@ -98,15 +98,12 @@ def __init__(self, run_on_pixels: bool, obs_res: tuple[int, int], self.ac_trans.append(t) def step(self, action: Action) -> EnvStepResult: - action = action.copy() for t in reversed(self.ac_trans): action = t.transform_action(action) - for _ in range(self.repeat_action_num): - res = self._step(action) - return res + return self._step(action, self.repeat_action_num) @abstractmethod - def _step(self, action: Action) -> EnvStepResult: + def _step(self, action: Action, repeat_num: int = 1) -> EnvStepResult: pass @abstractmethod @@ -170,7 +167,7 @@ def __init__(self, run_on_pixels: bool, obs_res: tuple[int, super().__init__(run_on_pixels, obs_res, repeat_action_num, transforms) def render(self): - return self.env.physics.render(*self.obs_res, camera_id=0) + return self.env.physics.render(*self.obs_res) def _uncode_ts(self, ts: TimeStep) -> EnvStepResult: if self.run_on_pixels: @@ -180,9 +177,19 @@ def _uncode_ts(self, ts: TimeStep) -> EnvStepResult: state = np.concatenate([state[s] for s in state], dtype=np.float32) return EnvStepResult(state, ts.reward, ts.last()) - def _step(self, action: Action) -> EnvStepResult: - # TODO: add action repeat to speed up DMC simulations - return self._uncode_ts(self.env.step(action)) + def _step(self, action: Action, repeat_num: int) -> EnvStepResult: + rew = 0 + for _ in range(repeat_num - 1): + ts = self.env.step(action) + rew += ts.reward or 0.0 + if ts.last(): + break + if repeat_num == 1 or not ts.last(): + env_res = self._uncode_ts(self.env.step(action)) + else: + env_res = ts + env_res.reward = np.tanh(rew + (env_res.reward or 0.0)) + return env_res def reset(self) -> EnvStepResult: return self._uncode_ts(self.env.reset()) diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index ea17909..2e692d8 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -29,7 +29,7 @@ class Rollout: # TODO: make buffer concurrent-friendly class ReplayBuffer: - def __init__(self, max_len=2_000): + 
def __init__(self, max_len=2e5): self.rollouts: deque[Rollout] = deque() self.rollouts_len: deque[int] = deque() self.curr_rollout = None @@ -83,24 +83,35 @@ def sample( seq_num = batch_size // cluster_size # NOTE: constant creation of numpy arrays from self.rollout_len seems terrible for me s, a, r, n, t = [], [], [], [], [] - r_indeces = np.random.choice(len(self.rollouts), + do_add_curr = self.curr_rollout is not None and len(self.curr_rollout.states) > cluster_size + r_indeces = np.random.choice(len(self.rollouts) + int(do_add_curr), seq_num, - p=np.array(self.rollouts_len) / self.total_num) + p=np.array(self.rollouts_len + deque([len(self.curr_rollout.states)] if do_add_curr else [])) / (self.total_num + int(do_add_curr)*len(self.curr_rollout.states))) + s_indeces = [] for r_idx in r_indeces: - # NOTE: maybe just no add such small rollouts to buffer - assert self.rollouts_len[r_idx] - cluster_size + 1 > 0, "Rollout it too small" - s_idx = np.random.choice(self.rollouts_len[r_idx] - cluster_size + 1, 1).item() - - s.append(self.rollouts[r_idx].states[s_idx:s_idx + cluster_size]) - a.append(self.rollouts[r_idx].actions[s_idx:s_idx + cluster_size]) - r.append(self.rollouts[r_idx].rewards[s_idx:s_idx + cluster_size]) - t.append(self.rollouts[r_idx].is_finished[s_idx:s_idx + cluster_size]) - if s_idx != self.rollouts_len[r_idx] - cluster_size: - n.append(self.rollouts[r_idx].states[s_idx+1:s_idx+1 + cluster_size]) + if r_idx != len(self.rollouts): + rollout, r_len = self.rollouts[r_idx], self.rollouts_len[r_idx] + else: + # -1 because we don't have next_state on terminal + rollout, r_len = self.curr_rollout, len(self.curr_rollout.states) - 1 + + # NOTE: maybe just not add such small rollouts to buffer + assert r_len > cluster_size - 1, "Rollout it too small" + s_idx = np.random.choice(r_len - cluster_size + 1, 1).item() + s_indeces.append(s_idx) + + if r_idx == len(self.rollouts): + r_len += 1 + + s.append(rollout.states[s_idx:s_idx + cluster_size]) + a.append(rollout.actions[s_idx:s_idx + cluster_size]) + r.append(rollout.rewards[s_idx:s_idx + cluster_size]) + t.append(rollout.is_finished[s_idx:s_idx + cluster_size]) + if s_idx != r_len - cluster_size: + n.append(rollout.states[s_idx+1:s_idx+1 + cluster_size]) else: if cluster_size != 1: - n.append(self.rollouts[r_idx].states[s_idx+1:s_idx+1 + cluster_size - 1]) - n.append(self.rollouts[r_idx].next_states) - - return (np.concatenate(s), np.concatenate(a), np.concatenate(r), + n.append(rollout.states[s_idx+1:s_idx+1 + cluster_size - 1]) + n.append(rollout.next_states) + return (np.concatenate(s), np.concatenate(a), np.concatenate(r, dtype=np.float32), np.concatenate(n), np.concatenate(t)) diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index cde5f21..b787c47 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -4,29 +4,38 @@ from unpackable import unpack from rl_sandbox.agents.random_agent import RandomAgent +from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.env import Env from rl_sandbox.utils.replay_buffer import ReplayBuffer, Rollout +from rl_sandbox.utils.replay_buffer import (Action, State, Observation) +def iter_rollout(env: Env, + agent: RlAgent, + collect_obs: bool = False) -> t.Generator[tuple[State, Action, float, State, bool, t.Optional[Observation]], None, None]: + state, _, terminated = unpack(env.reset()) + agent.reset() + + while not terminated: + action = agent.get_action(state) + + new_state, reward, terminated = 
unpack(env.step(action)) + + # FIXME: will break for non-DM + obs = env.render() if collect_obs else None + # if collect_obs and isinstance(env, dmEnv): + yield state, action, reward, new_state, terminated, obs + state = new_state -# FIXME: whole function duplicates a lot of code from main.py def collect_rollout(env: Env, - agent: t.Optional[t.Any] = None, + agent: t.Optional[RlAgent] = None, collect_obs: bool = False ) -> Rollout: - s, a, r, n, f, o = [], [], [], [], [], [] if agent is None: agent = RandomAgent(env) - state, _, terminated = unpack(env.reset()) - agent.reset() - - while not terminated: - action = agent.get_action(state) - - new_state, reward, terminated = unpack(env.step(action)) - + for state, action, reward, new_state, terminated, obs in iter_rollout(env, agent, collect_obs): s.append(state) a.append(action) r.append(reward) @@ -35,11 +44,8 @@ def collect_rollout(env: Env, # FIXME: will break for non-DM if collect_obs: - o.append(env.render()) - # if collect_obs and isinstance(env, dmEnv): - state = new_state + o.append(obs) - obs = None # match env: # case gym.Env(): # obs = np.stack(list(env.render())) if obs_res is not None else None From 59325946d9b198a1032f136bcc111ea404ebf9b8 Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 6 Dec 2022 13:47:07 +0000 Subject: [PATCH 019/106] Fix reconstruction & improve hist viz --- rl_sandbox/agents/dreamer_v2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 8b5c983..bcfb8be 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -314,7 +314,7 @@ def KL(dist1, dist2): x_r = self.image_predictor(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) - losses['loss_reconstruction'] = nn.functional.mse_loss(x_t, x_r) + losses['loss_reconstruction'] += nn.functional.mse_loss(x_t, x_r) losses['loss_reward_pred'] += F.mse_loss(r_t, r_t_pred) losses['loss_discount_pred'] += F.cross_entropy(f_t.type(torch.float32), f_t_pred) @@ -523,8 +523,8 @@ def viz_log(self, rollout, logger, epoch_num): ], axis=3) videos_comparison = np.expand_dims(np.concatenate([videos, videos_r], axis=2), 0) - latent_hist = (self._latent_probs / self._stored_steps * - 255.0).detach().cpu().numpy().reshape(-1, 32, 32) + latent_hist = (self._latent_probs / self._stored_steps).detach().cpu().numpy() + latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) action_hist = (self._action_probs / self._stored_steps).detach().cpu().numpy() # logger.add_histogram('val/action_probs', action_hist, epoch_num) @@ -532,8 +532,8 @@ def viz_log(self, rollout, logger, epoch_num): ax = fig.add_axes([0, 0, 1, 1]) ax.bar(np.arange(self.actions_num), action_hist) logger.add_figure('val/action_probs', fig, epoch_num) - logger.add_image('val/latent_probs', latent_hist, epoch_num) - logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=2), epoch_num) + logger.add_image('val/latent_probs', latent_hist, epoch_num, dataformats='HW') + logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=1), epoch_num, dataformats='HW') logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) def from_np(self, arr: np.ndarray): From 6bfb34bd45782d8792d7f85d4cec534152efc52d Mon Sep 17 00:00:00 2001 From: Midren Date: Wed, 7 Dec 2022 09:05:13 +0000 Subject: [PATCH 020/106] Add PersistentReplayBuffer based on WebDataset --- main.py | 8 +- pyproject.toml | 1 + rl_sandbox/agents/dreamer_v2.py | 7 +- 
rl_sandbox/utils/persistent_replay_buffer.py | 110 +++++++++++++++++++ rl_sandbox/utils/replay_buffer.py | 2 + 5 files changed, 123 insertions(+), 5 deletions(-) create mode 100644 rl_sandbox/utils/persistent_replay_buffer.py diff --git a/main.py b/main.py index a5c132c..2568987 100644 --- a/main.py +++ b/main.py @@ -3,13 +3,14 @@ from omegaconf import DictConfig, OmegaConf from torch.utils.tensorboard.writer import SummaryWriter from tqdm import tqdm -from unpackable import unpack +from pathlib import Path from rl_sandbox.agents.random_agent import RandomAgent from rl_sandbox.agents.explorative_agent import ExplorativeAgent from rl_sandbox.metrics import MetricsEvaluator from rl_sandbox.utils.env import Env -from rl_sandbox.utils.replay_buffer import ReplayBuffer +# from rl_sandbox.utils.replay_buffer import ReplayBuffer +from rl_sandbox.utils.persistent_replay_buffer import PersistentReplayBuffer from rl_sandbox.utils.rollout_generation import (collect_rollout, collect_rollout_num, iter_rollout, fillup_replay_buffer) from rl_sandbox.utils.schedulers import LinearScheduler @@ -23,7 +24,7 @@ def main(cfg: DictConfig): # TODO: add replay buffer implementation, which stores rollouts # on disk - buff = ReplayBuffer() + buff = PersistentReplayBuffer(Path('rollouts/')) fillup_replay_buffer(env, buff, cfg.training.batch_size) metrics_evaluator = MetricsEvaluator() @@ -68,6 +69,7 @@ def main(cfg: DictConfig): ### Validation if epoch_num % cfg.training.val_logs_every == 0: rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) + # TODO: make logs visualization in separate process metrics = metrics_evaluator.calculate_metrics(rollouts) for metric_name, metric in metrics.items(): writer.add_scalar(f'val/{metric_name}', metric, epoch_num) diff --git a/pyproject.toml b/pyproject.toml index 3e22666..d16679d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,3 +27,4 @@ dm-control = '^1.0.0' unpackable = '^0.0.4' hydra-core = "^1.2.0" matplotlib = "^3.0.0" +webdataset = "^0.2.20" diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index bcfb8be..ad1e9db 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -135,7 +135,7 @@ def predict_next(self, # used for KL divergence predicted_stoch_latent = self.estimate_stochastic_latent(determ) - return deter_state, predicted_stoch_latent + return determ, predicted_stoch_latent def update_current(self, determ, embed): # Dreamer 'obs_out' return self.stoch_net(torch.concat([determ, embed], dim=2)) @@ -318,6 +318,7 @@ def KL(dist1, dist2): losses['loss_reward_pred'] += F.mse_loss(r_t, r_t_pred) losses['loss_discount_pred'] += F.cross_entropy(f_t.type(torch.float32), f_t_pred) + # NOTE: entropy can be added as metric losses['loss_kl_reg'] += KL(prior_stoch_dist, posterior_stoch_dist) h_prev = [determ_t, posterior_stoch.unsqueeze(0)] @@ -470,6 +471,7 @@ def reset(self): self._stored_steps = 0 def preprocess_obs(self, obs: torch.Tensor): + # FIXME: move to dataloader in replay buffer order = list(range(len(obs.shape))) # Swap channel from last to 3 from last order = order[:-3] + [order[-1]] + order[-3:-1] @@ -537,7 +539,8 @@ def viz_log(self, rollout, logger, epoch_num): logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) def from_np(self, arr: np.ndarray): - return torch.from_numpy(arr).to(next(self.world_model.parameters()).device) + arr = torch.from_numpy(arr) if isinstance(arr, np.ndarray) else arr + return arr.to(next(self.world_model.parameters()).device, 
non_blocking=True) def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observations, is_finished: TerminationFlags): diff --git a/rl_sandbox/utils/persistent_replay_buffer.py b/rl_sandbox/utils/persistent_replay_buffer.py new file mode 100644 index 0000000..bacf666 --- /dev/null +++ b/rl_sandbox/utils/persistent_replay_buffer.py @@ -0,0 +1,110 @@ +import typing as t +from collections import deque +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import webdataset as wds + +from rl_sandbox.utils.replay_buffer import (Action, Actions, Observation, + Observations, Rewards, Rollout, + State, States, TerminationFlags) + + +# TODO: add tagging of replay buffer meta-data (env config) +# to omit incompatible cache +class PersistentReplayBuffer: + + def __init__(self, directory: Path, max_len=1e6): + self.max_len: int = int(max_len) + self.directory = directory + self.directory.mkdir(exist_ok=True) + self.rollouts: list[str] = list(map(str, self.directory.glob('*.tar'))) + self.rollouts_num = len(self.rollouts) + # FIXME: add correct length calculation, currently hardcoded + self.rollouts_len: list[int] = [200] * self.rollouts_num + self.total_num = sum(self.rollouts_len) + self.rollout_idx = self.rollouts_num + + self.curr_rollout: t.Optional[Rollout] = None + self.rollouts_changed: bool = True + + def add_rollout(self, rollout: Rollout): + name = str(self.directory / f'rollout-{self.rollout_idx % self.max_len}.tar') + sink = wds.TarWriter(name) + + for idx in range(len(rollout)): + s, a, r, t = rollout.states[idx], rollout.actions[idx], rollout.rewards[ + idx], rollout.is_finished[idx] + sink.write({ + "__key__": "sample%06d" % idx, + "state.pyd": s, + "action.pyd": a, + "reward.pyd": np.array(r, dtype=np.float32), + "is_finished.pyd": np.array(t, dtype=np.bool_) + }) + + if self.rollout_idx < self.max_len: + self.total_num += len(rollout) + self.rollouts_num += 1 + self.rollouts.append(name) + self.rollouts_len.append(len(rollout)) + else: + self.total_num += len(rollout) - self.rollouts_len[self.rollout_idx % + self.max_len] + self.rollouts[self.rollout_idx % self.max_len] = name + self.rollouts_len[self.rollout_idx % self.max_len] = len(rollout) + self.rollout_idx += 1 + self.rollouts_changed = True + + # Add sample expects that each subsequent sample + # will be continuation of last rollout util termination flag true + # is encountered + def add_sample(self, s: State, a: Action, r: float, n: State, f: bool): + if self.curr_rollout is None: + self.curr_rollout = Rollout([s], [a], [r], None, [f]) + else: + self.curr_rollout.states.append(s) + self.curr_rollout.actions.append(a) + self.curr_rollout.rewards.append(r) + self.curr_rollout.is_finished.append(f) + + if f: + self.add_rollout( + Rollout(np.array(self.curr_rollout.states), + np.array(self.curr_rollout.actions), + np.array(self.curr_rollout.rewards, dtype=np.float32), + np.array([n]), np.array(self.curr_rollout.is_finished))) + self.curr_rollout = None + + def can_sample(self, num: int): + return self.total_num >= num + + @staticmethod + def add_next(src): + s, a, r, t = src + return s[:-1], a[:-1], r[:-1], s[1:], t[:-1] + + def sample( + self, + batch_size: int, + cluster_size: int = 1 + ) -> tuple[States, Actions, Rewards, States, TerminationFlags]: + seq_num = batch_size // cluster_size + # TODO: Could be done in async before + # NOTE: maybe use WDS_REWRITE + + if self.rollouts_changed: + # NOTE: shardshuffle will specify amount of urls that will be taken + # into account. 
Sorting not everything doesn't make sense + self.dataset = wds.WebDataset(self.rollouts + ).decode().to_tuple("state.pyd", "action.pyd", "reward.pyd", "is_finished.pyd" + # NOTE: does not take into account is_finished + ).batched(cluster_size + 1, partial=False + ).map(self.add_next).batched(seq_num) + # NOTE: in WebDataset github, it is recommended to use such batching by ourselves + # https://github.com/webdataset/webdataset#dataloader + self.loader = iter( + wds.WebLoader(self.dataset, batch_size=None, + num_workers=4, pin_memory=True).unbatched().shuffle(1000).unbatched().batched(batch_size)) + return next(self.loader) diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index 2e692d8..ae83ca7 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -25,6 +25,8 @@ class Rollout: is_finished: TerminationFlags observations: t.Optional[Observations] = None + def __len__(self): + return len(self.states) # TODO: make buffer concurrent-friendly class ReplayBuffer: From c88cd045c346f33176bed8fb087789c5a176a85f Mon Sep 17 00:00:00 2001 From: Midren Date: Fri, 9 Dec 2022 01:58:50 +0000 Subject: [PATCH 021/106] add async iteration --- main.py | 2 +- rl_sandbox/agents/dreamer_v2.py | 3 +- rl_sandbox/utils/rollout_generation.py | 65 +++++++++++++++++++++----- 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/main.py b/main.py index 2568987..b89c801 100644 --- a/main.py +++ b/main.py @@ -51,7 +51,7 @@ def main(cfg: DictConfig): ### Training and exploration # TODO: add buffer end prioritarization - for s, a, r, n, f, _ in iter_rollout(env, agent): + for s, a, r, n, f, _ in iter_rollout_async(env, agent): buff.add_sample(s, a, r, n, f) if global_step % cfg.training.gradient_steps_per_step == 0: diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index ad1e9db..d54a007 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -470,7 +470,8 @@ def reset(self): self._action_probs = torch.zeros((self.actions_num), device=next(self.world_model.parameters()).device) self._stored_steps = 0 - def preprocess_obs(self, obs: torch.Tensor): + @staticmethod + def preprocess_obs(obs: torch.Tensor): # FIXME: move to dataloader in replay buffer order = list(range(len(obs.shape))) # Swap channel from last to 3 from last diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index b787c47..a5446f1 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -1,17 +1,56 @@ import typing as t +from multiprocessing.synchronize import Lock import numpy as np +import torch.multiprocessing as mp from unpackable import unpack from rl_sandbox.agents.random_agent import RandomAgent from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.env import Env -from rl_sandbox.utils.replay_buffer import ReplayBuffer, Rollout -from rl_sandbox.utils.replay_buffer import (Action, State, Observation) +from rl_sandbox.utils.replay_buffer import (Action, Observation, ReplayBuffer, + Rollout, State) -def iter_rollout(env: Env, - agent: RlAgent, - collect_obs: bool = False) -> t.Generator[tuple[State, Action, float, State, bool, t.Optional[Observation]], None, None]: + +def _async_env_worker(env: Env, obs_queue: mp.Queue, act_queue: mp.Queue): + state, _, terminated = unpack(env.reset()) + obs_queue.put((state, 0, terminated), block=False) + + while not terminated: + action = act_queue.get(block=True) + + new_state, reward, 
terminated = unpack(env.step(action)) + del action + obs_queue.put((state, reward, terminated), block=False) + + state = new_state + + +def iter_rollout_async( + env: Env, + agent: RlAgent +) -> t.Generator[tuple[State, Action, float, State, bool, t.Optional[Observation]], None, + None]: + # NOTE: maybe use SharedMemory instead + obs_queue = mp.Queue(1) + a_queue = mp.Queue(1) + p = mp.Process(target=_async_env_worker, args=(env, obs_queue, a_queue)) + p.start() + terminated = False + + while not terminated: + state, reward, terminated = obs_queue.get(block=True) + action = agent.get_action(state) + a_queue.put(action) + yield state, action, reward, None, terminated, state + + +def iter_rollout( + env: Env, + agent: RlAgent, + collect_obs: bool = False +) -> t.Generator[tuple[State, Action, float, State, bool, t.Optional[Observation]], None, + None]: state, _, terminated = unpack(env.reset()) agent.reset() @@ -26,16 +65,17 @@ def iter_rollout(env: Env, yield state, action, reward, new_state, terminated, obs state = new_state + def collect_rollout(env: Env, agent: t.Optional[RlAgent] = None, - collect_obs: bool = False - ) -> Rollout: + collect_obs: bool = False) -> Rollout: s, a, r, n, f, o = [], [], [], [], [], [] if agent is None: agent = RandomAgent(env) - for state, action, reward, new_state, terminated, obs in iter_rollout(env, agent, collect_obs): + for state, action, reward, new_state, terminated, obs in iter_rollout( + env, agent, collect_obs): s.append(state) a.append(action) r.append(reward) @@ -51,7 +91,10 @@ def collect_rollout(env: Env, # obs = np.stack(list(env.render())) if obs_res is not None else None # case dmEnv(): obs = np.array(o) if collect_obs is not None else None - return Rollout(np.array(s), np.array(a).reshape(len(s), -1), np.array(r, dtype=np.float32), np.array(n), np.array(f), obs) + return Rollout(np.array(s), + np.array(a).reshape(len(s), -1), np.array(r, dtype=np.float32), + np.array(n), np.array(f), obs) + def collect_rollout_num(env: Env, num: int, @@ -64,9 +107,7 @@ def collect_rollout_num(env: Env, return rollouts -def fillup_replay_buffer(env: Env, - rep_buffer: ReplayBuffer, - num: int): +def fillup_replay_buffer(env: Env, rep_buffer: ReplayBuffer, num: int): # TODO: paralelyze while not rep_buffer.can_sample(num): rep_buffer.add_rollout(collect_rollout(env, collect_obs=False)) From 91e96c4eb5abe2cb9946e00f4e89433b9bb819e8 Mon Sep 17 00:00:00 2001 From: Midren Date: Sun, 18 Dec 2022 21:41:31 +0000 Subject: [PATCH 022/106] Improved the performance of training - Added Pytorch Profiler - Repeat action number should be used to indicate valid env steps - Disable distribution args validation - Moved image reconstruction, reward&discount prediction out of RNN loop - Added pretrain step --- config/agent/dreamer_v2.yaml | 2 +- config/config.yaml | 15 ++- main.py | 51 +++++--- rl_sandbox/agents/dreamer_v2.py | 206 +++++++++++++++++------------- rl_sandbox/utils/replay_buffer.py | 3 +- 5 files changed, 163 insertions(+), 114 deletions(-) diff --git a/config/agent/dreamer_v2.yaml b/config/agent/dreamer_v2.yaml index 806e19f..a470602 100644 --- a/config/agent/dreamer_v2.yaml +++ b/config/agent/dreamer_v2.yaml @@ -9,7 +9,7 @@ kl_loss_balancing: 0.8 world_model_lr: 3e-4 # ActorCritic parameters -discount_factor: 0.995 +discount_factor: 0.99 imagination_horizon: 15 actor_lr: 8e-5 diff --git a/config/config.yaml b/config/config.yaml index b6e9c26..94b1539 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -4,15 +4,20 @@ defaults: - _self_ seed: 42 
-device_type: cpu +device_type: cuda training: - epochs: 5000 + steps: 1e6 + prefill: 1000 + pretrain: 0 batch_size: 1024 - gradient_steps_per_step: 4 - save_checkpoint_every: 1000 - val_logs_every: 5 + gradient_steps_per_step: 5 + save_checkpoint_every: 1e5 + val_logs_every: 5e3 validation: rollout_num: 5 visualize: true + +debug: + profiler: false diff --git a/main.py b/main.py index b89c801..16b963d 100644 --- a/main.py +++ b/main.py @@ -5,13 +5,16 @@ from tqdm import tqdm from pathlib import Path +import torch +from torch.profiler import profile, record_function, ProfilerActivity + from rl_sandbox.agents.random_agent import RandomAgent from rl_sandbox.agents.explorative_agent import ExplorativeAgent from rl_sandbox.metrics import MetricsEvaluator from rl_sandbox.utils.env import Env -# from rl_sandbox.utils.replay_buffer import ReplayBuffer +from rl_sandbox.utils.replay_buffer import ReplayBuffer from rl_sandbox.utils.persistent_replay_buffer import PersistentReplayBuffer -from rl_sandbox.utils.rollout_generation import (collect_rollout, collect_rollout_num, iter_rollout, +from rl_sandbox.utils.rollout_generation import (collect_rollout, collect_rollout_num, iter_rollout, iter_rollout_async, fillup_replay_buffer) from rl_sandbox.utils.schedulers import LinearScheduler @@ -19,13 +22,15 @@ @hydra.main(version_base="1.2", config_path='config', config_name='config') def main(cfg: DictConfig): # print(OmegaConf.to_yaml(cfg)) + torch.distributions.Distribution.set_default_validate_args(False) + torch.backends.cudnn.benchmark = True env: Env = hydra.utils.instantiate(cfg.env) # TODO: add replay buffer implementation, which stores rollouts # on disk - buff = PersistentReplayBuffer(Path('rollouts/')) - fillup_replay_buffer(env, buff, cfg.training.batch_size) + buff = ReplayBuffer() + fillup_replay_buffer(env, buff, max(cfg.training.prefill, cfg.training.batch_size)) metrics_evaluator = MetricsEvaluator() @@ -45,13 +50,25 @@ def main(cfg: DictConfig): LinearScheduler(0.9, 0.01, 5_000)) writer = SummaryWriter() + prof = profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], + on_trace_ready=torch.profiler.tensorboard_trace_handler('runs/profile_dreamer'), + schedule=torch.profiler.schedule(wait=10, warmup=10, active=5, repeat=5), + with_stack=True) if cfg.debug.profiler else None + + for i in tqdm(range(cfg.training.pretrain), desc='Pretraining'): + s, a, r, n, f = buff.sample(cfg.training.batch_size, + cluster_size=cfg.agent.get('batch_cluster_size', 1)) + losses = agent.train(s, a, r, n, f) + for loss_name, loss in losses.items(): + writer.add_scalar(f'pre_train/{loss_name}', loss, i) + global_step = 0 - pbar = tqdm(total=cfg.training.epochs*200) - for epoch_num in range(cfg.training.epochs): + pbar = tqdm(total=cfg.training.steps, desc='Training') + while global_step < cfg.training.steps: ### Training and exploration # TODO: add buffer end prioritarization - for s, a, r, n, f, _ in iter_rollout_async(env, agent): + for s, a, r, n, f, _ in iter_rollout(env, agent): buff.add_sample(s, a, r, n, f) if global_step % cfg.training.gradient_steps_per_step == 0: @@ -61,31 +78,35 @@ def main(cfg: DictConfig): cluster_size=cfg.agent.get('batch_cluster_size', 1)) losses = agent.train(s, a, r, n, f) + if cfg.debug.profiler: + prof.step() for loss_name, loss in losses.items(): writer.add_scalar(f'train/{loss_name}', loss, global_step) - global_step += 1 - pbar.update(1) + global_step += cfg.env.repeat_action_num + pbar.update(cfg.env.repeat_action_num) ### Validation - if epoch_num % 
cfg.training.val_logs_every == 0: + if global_step % cfg.training.val_logs_every == 0: rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) # TODO: make logs visualization in separate process metrics = metrics_evaluator.calculate_metrics(rollouts) for metric_name, metric in metrics.items(): - writer.add_scalar(f'val/{metric_name}', metric, epoch_num) + writer.add_scalar(f'val/{metric_name}', metric, global_step) if cfg.validation.visualize: rollouts = collect_rollout_num(env, 1, agent, collect_obs=True) for rollout in rollouts: video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) - writer.add_video('val/visualization', video, epoch_num) + writer.add_video('val/visualization', video, global_step) # FIXME: Very bad from architecture point - agent.policy_ag.viz_log(rollout, writer, epoch_num) + agent.policy_ag.viz_log(rollout, writer, global_step) ### Checkpoint - if epoch_num % cfg.training.save_checkpoint_every == 0: - agent.save_ckpt(epoch_num, losses) + if global_step % cfg.training.save_checkpoint_every == 0: + agent.save_ckpt(global_step, losses) + if cfg.debug.profiler: + prof.stop() if __name__ == "__main__": diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index d54a007..9ff1f44 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -7,6 +7,7 @@ import torch from torch import nn from torch.nn import functional as F +import torch.distributions as td from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.fc_nn import fc_nn_generator @@ -25,25 +26,15 @@ def forward(self, x): return x.view(*self.shape) -class DebugShapeLayer(nn.Module): - - def __init__(self, note=""): - super().__init__() - self.note = note - - def forward(self, x): - if len(self.note): - print(self.note, x.shape) - else: - print(x.shape) - return x - - class Quantize(nn.Module): def forward(self, logits): - return torch.distributions.one_hot_categorical.OneHotCategoricalStraightThrough( - logits=logits) + return logits + # return torch.distributions.one_hot_categorical.OneHotCategoricalStraightThrough( + # logits=logits) + +def Dist(val): + return td.OneHotCategoricalStraightThrough(logits=val) class RSSM(nn.Module): @@ -134,8 +125,8 @@ def predict_next(self, _, determ = self.determ_recurrent(x, deter_state) # used for KL divergence - predicted_stoch_latent = self.estimate_stochastic_latent(determ) - return determ, predicted_stoch_latent + predicted_stoch_logits = self.estimate_stochastic_latent(determ) + return determ, predicted_stoch_logits def update_current(self, determ, embed): # Dreamer 'obs_out' return self.stoch_net(torch.concat([determ, embed], dim=2)) @@ -161,12 +152,12 @@ def forward(self, h_prev: t.Optional[tuple[torch.Tensor, torch.Tensor]], embed, (*action.shape[:-1], self.latent_dim * self.latent_classes), device=next(self.stoch_net.parameters()).device)) deter_prev, stoch_prev = h_prev - determ, prior_stoch_dist = self.predict_next(stoch_prev, + determ, prior_stoch_logits = self.predict_next(stoch_prev, action, deter_state=deter_prev) - posterior_stoch_dist = self.update_current(determ, embed) + posterior_stoch_logits = self.update_current(determ, embed) - return [determ, prior_stoch_dist, posterior_stoch_dist] + return [determ, prior_stoch_logits, posterior_stoch_logits] # NOTE: residual blocks are not used inside dreamer @@ -182,8 +173,6 @@ def __init__(self, kernel_sizes=[4, 4, 4, 4]): out_channels = 2**i * channel_step layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, 
stride=2)) layers.append(nn.ELU(inplace=True)) - # FIXME: change to layer norm when sizes will be known - layers.append(nn.BatchNorm2d(out_channels)) in_channels = out_channels layers.append(nn.Flatten()) self.net = nn.Sequential(*layers) @@ -212,7 +201,6 @@ def __init__(self, input_size, kernel_sizes=[5, 5, 6, 6]): nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=2)) layers.append(nn.ELU(inplace=True)) - layers.append(nn.BatchNorm2d(out_channels)) in_channels = out_channels self.net = nn.Sequential(*layers) @@ -255,10 +243,10 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, final_activation=nn.Sigmoid) def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor]): - determ_state, next_repr_dist = self.recurrent_model.predict_next( + determ_state, next_repr_logits = self.recurrent_model.predict_next( latent_repr.unsqueeze(0), action.unsqueeze(0), world_state) - next_repr = next_repr_dist.rsample().reshape( + next_repr = Dist(next_repr_logits).rsample().reshape( -1, self.latent_dim * self.latent_classes) reward = self.reward_predictor( torch.concat([determ_state.squeeze(0), next_repr], dim=1)) @@ -268,65 +256,97 @@ def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor def get_latent(self, obs: torch.Tensor, action, state): embed = self.encoder(obs) - determ, _, latent_repr_dist = self.recurrent_model(state, embed.unsqueeze(0), + determ, _, latent_repr_logits = self.recurrent_model.forward(state, embed.unsqueeze(0), action) - return determ, latent_repr_dist + return determ, latent_repr_logits def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, is_finished: torch.Tensor): b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) - embed = embed.view(b // self.cluster_size, self.cluster_size, -1) + embed_c = embed.view(b // self.cluster_size, self.cluster_size, -1) - obs = obs.view(-1, self.cluster_size, 3, h, w) - a = a.view(-1, self.cluster_size, self.actions_num) - r = r.view(-1, self.cluster_size, 1) - f = is_finished.view(-1, self.cluster_size, 1) + obs_c = obs.view(-1, self.cluster_size, 3, h, w) + a_c = a.view(-1, self.cluster_size, self.actions_num) + r_c = r.view(-1, self.cluster_size, 1) + f_c = is_finished.view(-1, self.cluster_size, 1) h_prev = None losses = defaultdict(lambda: torch.zeros(1).to(next(self.parameters()).device)) - def KL(dist1, dist2): + def KL(dist1, dist2, clusterify: bool = False): KL_ = torch.distributions.kl_divergence - Dist = torch.distributions.OneHotCategoricalStraightThrough - one = torch.ones(1,device=next(self.parameters()).device) - kl_lhs = torch.maximum(KL_(Dist(logits=dist2.logits.detach()), dist1), one) - kl_rhs = torch.maximum(KL_(dist2, Dist(logits=dist1.logits.detach())), one) + one = torch.zeros(1,device=next(self.parameters()).device) + if clusterify: + kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)), one).view(-1, self.cluster_size).mean(dim=0).sum() + kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())), one).view(-1, self.cluster_size).mean(dim=0).sum() + else: + kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)), one) + kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())), one) return self.kl_beta * (self.alpha * kl_lhs.mean()+ (1 - self.alpha) * kl_rhs.mean()) latent_vars = [] + determ_vars = [] + prior_logits = [] + posterior_logits = [] + inps = [] + x_recovered = [] for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 - x_t, embed_t, a_t, r_t, f_t = obs[:, t], 
embed[:, t].unsqueeze( - 0), a[:, t].unsqueeze(0), r[:, t], f[:, t] + x_t, embed_t, a_t, r_t, f_t = obs_c[:, t], embed_c[:, t].unsqueeze( + 0), a_c[:, t].unsqueeze(0), r_c[:, t], f_c[:, t] - determ_t, prior_stoch_dist, posterior_stoch_dist = self.recurrent_model( + determ_t, prior_stoch_logits, posterior_stoch_logits = self.recurrent_model.forward( h_prev, embed_t, a_t) + posterior_stoch_dist = Dist(posterior_stoch_logits) posterior_stoch = posterior_stoch_dist.rsample().reshape( -1, self.latent_dim * self.latent_classes) - r_t_pred = self.reward_predictor( - torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1).detach()) - f_t_pred = self.discount_predictor( - torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) - - x_r = self.image_predictor(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) + # x_r = self.image_predictor(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) + # inps.append(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) + # x_recovered.append(x_r) - losses['loss_reconstruction'] += nn.functional.mse_loss(x_t, x_r) - losses['loss_reward_pred'] += F.mse_loss(r_t, r_t_pred) - losses['loss_discount_pred'] += F.cross_entropy(f_t.type(torch.float32), - f_t_pred) - # NOTE: entropy can be added as metric - losses['loss_kl_reg'] += KL(prior_stoch_dist, posterior_stoch_dist) + # losses['loss_reconstruction'] += nn.functional.mse_loss(x_t, x_r) + # losses['loss_kl_reg'] += KL(prior_stoch_logits, posterior_stoch_logits) h_prev = [determ_t, posterior_stoch.unsqueeze(0)] - latent_vars.append(posterior_stoch.detach()) + determ_vars.append(determ_t.squeeze(0)) + latent_vars.append(posterior_stoch) - discovered_latents = torch.stack(latent_vars).reshape( - -1, self.latent_dim * self.latent_classes) - return losses, discovered_latents + prior_logits.append(prior_stoch_logits) + posterior_logits.append(posterior_stoch_logits) + + # r_t_pred = self.reward_predictor( + # torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1).detach()) + # f_t_pred = self.discount_predictor( + # torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) + + # x_r = self.image_predictor(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) + + # losses['loss_reconstruction'] += nn.functional.mse_loss(x_t, x_r) + # losses['loss_reward_pred'] += F.mse_loss(r_t, r_t_pred) + # losses['loss_discount_pred'] += F.cross_entropy(f_t.type(torch.float32), + # f_t_pred) + # # NOTE: entropy can be added as metric + # losses['loss_kl_reg'] += KL(prior_stoch_logits, posterior_stoch_logits) + + + # inp = torch.concat([determ_vars.squeeze(0), posterior_stoch], dim=1) + inp = torch.concat([torch.concat(determ_vars), torch.concat(latent_vars)], dim=1) + r_pred = self.reward_predictor(inp).squeeze() + f_pred = self.discount_predictor(inp).squeeze() + x_r = self.image_predictor(inp) + + losses['loss_reconstruction'] += F.mse_loss(obs, x_r, reduction='none').view(-1, self.cluster_size, 3, h, w).mean(dim=(0, 2, 3, 4)).sum() + losses['loss_reward_pred'] += F.mse_loss(r, r_pred, reduction='none').mean(dim=0).sum() + losses['loss_discount_pred'] += F.cross_entropy(is_finished.type(torch.float32), + f_pred, reduction='none').mean(dim=0).sum() + # NOTE: entropy can be added as metric + losses['loss_kl_reg'] += KL(torch.concat(prior_logits), torch.concat(posterior_logits), True) + + return losses, torch.stack(latent_vars).reshape(-1, self.latent_dim * self.latent_classes).detach() class ImaginativeCritic(nn.Module): @@ -447,7 +467,7 @@ def imagine_trajectory( zs, actions, next_zs, 
rewards, ts, determs = [], [], [], [], [], [] z = z_0.detach() for _ in range(self.imagination_horizon): - a = self.actor(z) + a = Dist(self.actor(z)) world_state, next_z, reward, is_finished = self.world_model.predict_next( z, a.rsample(), world_state) @@ -483,12 +503,13 @@ def get_action(self, obs: Observation) -> Action: obs = torch.from_numpy(obs.copy()).to(next(self.world_model.parameters()).device) obs = self.preprocess_obs(obs).unsqueeze(0) - determ, latent_repr_dist = self.world_model.get_latent(obs, self._last_action, + determ, latent_repr_logits = self.world_model.get_latent(obs, self._last_action, self._state) + latent_repr_dist = Dist(latent_repr_logits) self._state = (determ, latent_repr_dist.rsample().reshape(-1, 32 * 32).unsqueeze(0)) - actor_dist = self.actor(self._state[1]) + actor_dist = Dist(self.actor(self._state[1])) self._last_action = actor_dist.rsample() self._action_probs += actor_dist.probs.squeeze() @@ -503,9 +524,9 @@ def _generate_video(self, obs: Observation, init_action: Action): action = F.one_hot(self.from_np(init_action).to(torch.int64), num_classes=self.actions_num).squeeze() - z_0 = self.world_model.get_latent(obs, + z_0 = Dist(self.world_model.get_latent(obs, action.unsqueeze(0).unsqueeze(0), - None)[1].rsample().reshape(-1, + None)[1]).rsample().reshape(-1, 32 * 32).unsqueeze(0) zs, _, _, _, _, determs = self.imagine_trajectory(z_0.squeeze(0)) video_r = self.world_model.image_predictor(torch.concat([determs, zs], dim=2)).cpu().detach().numpy() @@ -566,39 +587,40 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation world_model_loss.backward() self.world_model_optimizer.step() - idx = torch.randperm(discovered_latents.size(0)) - initial_states = discovered_latents[idx] - - losses_ac = defaultdict( - lambda: torch.zeros(1).to(next(self.critic.parameters()).device)) - - zs, action_dists, next_zs, rewards, terminal_flags, _ = self.imagine_trajectory( - initial_states) - vs = self.critic.lambda_return(next_zs, rewards, terminal_flags) - - losses_ac['loss_critic'] = F.mse_loss(self.critic.estimate_value(next_zs.detach()), - vs.detach(), - reduction='sum') / zs.shape[1] - losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control - losses_ac['loss_actor_dynamics_backprop'] = -( - (1 - self.rho) * vs).sum() / zs.shape[1] - losses_ac['loss_actor_entropy'] = -self.eta * torch.stack( - [a.entropy() for a in action_dists[:-1]]).sum() / zs.shape[1] - losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ - 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] - - self.actor_optimizer.zero_grad() - self.critic_optimizer.zero_grad() - losses_ac['loss_critic'].backward() - losses_ac['loss_actor'].backward() - self.actor_optimizer.step() - self.critic_optimizer.step() - self.critic.update_target() + # idx = torch.randperm(discovered_latents.size(0)) + # initial_states = discovered_latents[idx] + + # losses_ac = defaultdict( + # lambda: torch.zeros(1).to(next(self.critic.parameters()).device)) + + # zs, action_dists, next_zs, rewards, terminal_flags, _ = self.imagine_trajectory( + # initial_states) + # vs = self.critic.lambda_return(next_zs, rewards, terminal_flags) + + # losses_ac['loss_critic'] = F.mse_loss(self.critic.estimate_value(next_zs.detach()), + # vs.detach(), + # reduction='sum') / zs.shape[1] + # losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control + # losses_ac['loss_actor_dynamics_backprop'] = -( + # (1 - self.rho) * vs).sum() / zs.shape[1] + # losses_ac['loss_actor_entropy'] = 
-self.eta * torch.stack( + # [a.entropy() for a in action_dists[:-1]]).sum() / zs.shape[1] + # losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ + # 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] + + # self.actor_optimizer.zero_grad() + # self.critic_optimizer.zero_grad() + # losses_ac['loss_critic'].backward() + # losses_ac['loss_actor'].backward() + # self.actor_optimizer.step() + # self.critic_optimizer.step() + # self.critic.update_target() losses = {l: val.detach().cpu().item() for l, val in losses.items()} - losses_ac = {l: val.detach().cpu().item() for l, val in losses_ac.items()} + # losses_ac = {l: val.detach().cpu().item() for l, val in losses_ac.items()} - return losses | losses_ac + # return losses | losses_ac + return losses def save_ckpt(self, epoch_num: int, losses: dict[str, float]): torch.save( diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index ae83ca7..39b8c4b 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -86,9 +86,10 @@ def sample( # NOTE: constant creation of numpy arrays from self.rollout_len seems terrible for me s, a, r, n, t = [], [], [], [], [] do_add_curr = self.curr_rollout is not None and len(self.curr_rollout.states) > cluster_size + tot = self.total_num + (len(self.curr_rollout.states) if do_add_curr else 0) r_indeces = np.random.choice(len(self.rollouts) + int(do_add_curr), seq_num, - p=np.array(self.rollouts_len + deque([len(self.curr_rollout.states)] if do_add_curr else [])) / (self.total_num + int(do_add_curr)*len(self.curr_rollout.states))) + p=np.array(self.rollouts_len + deque([len(self.curr_rollout.states)] if do_add_curr else [])) / tot) s_indeces = [] for r_idx in r_indeces: if r_idx != len(self.rollouts): From 3c46b052b69ab48595d894a893e74762487c1ab7 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 7 Jan 2023 12:20:57 +0000 Subject: [PATCH 023/106] Added mock env, which just iterates over different colors --- config/env/mock.yaml | 5 +++++ rl_sandbox/utils/env.py | 26 ++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) create mode 100644 config/env/mock.yaml diff --git a/config/env/mock.yaml b/config/env/mock.yaml new file mode 100644 index 0000000..ec28cf3 --- /dev/null +++ b/config/env/mock.yaml @@ -0,0 +1,5 @@ +_target_: rl_sandbox.utils.env.MockEnv +run_on_pixels: true +obs_res: [64, 64] +repeat_action_num: 1 +transforms: [] diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py index 2b8fe93..adbcd1e 100644 --- a/rl_sandbox/utils/env.py +++ b/rl_sandbox/utils/env.py @@ -157,6 +157,32 @@ def _observation_space(self): def _action_space(self): return self.env.action_space +class MockEnv(Env): + + def __init__(self, run_on_pixels: bool, + obs_res: tuple[int, int], repeat_action_num: int, + transforms: list[ActionTransformer]): + super().__init__(run_on_pixels, obs_res, repeat_action_num, transforms) + self.max_steps = 255 + self.step_count = 0 + + def _step(self, action: Action, repeat_num: int) -> EnvStepResult: + self.step_count += repeat_num + return EnvStepResult(self.render(), self.step_count, self.step_count >= self.max_steps) + + def reset(self): + self.step_count = 0 + return EnvStepResult(self.render(), 0, False) + + def render(self): + return np.ones(self.obs_res + (3, )) * self.step_count + + def _observation_space(self): + return gym.spaces.Box(0, 255, self.obs_res + (3, ), dtype=np.uint8) + + def _action_space(self): + return gym.spaces.Box(-1, 1, (1, ), dtype=np.float32) + class DmEnv(Env): From 
90bb2a4dcd651a510ebcbbb8c4301fa39e15c440 Mon Sep 17 00:00:00 2001 From: Midren Date: Fri, 6 Jan 2023 17:44:55 +0000 Subject: [PATCH 024/106] Fixed clustering of samples and correct input for world-model components Additionally: - Changed actor, reward/discount predictors to distribution output - Enabled tf32_ matmul for faster computation - Added float16 autocast --- config/agent/dreamer_v2.yaml | 1 + config/config.yaml | 7 +- config/env/dm_cartpole.yaml | 5 +- main.py | 56 ++++- rl_sandbox/agents/dreamer_v2.py | 329 ++++++++++++++++++------------ rl_sandbox/utils/env.py | 10 +- rl_sandbox/utils/fc_nn.py | 4 +- rl_sandbox/utils/replay_buffer.py | 6 +- 8 files changed, 264 insertions(+), 154 deletions(-) diff --git a/config/agent/dreamer_v2.yaml b/config/agent/dreamer_v2.yaml index a470602..961aea0 100644 --- a/config/agent/dreamer_v2.yaml +++ b/config/agent/dreamer_v2.yaml @@ -6,6 +6,7 @@ latent_classes: 32 rssm_dim: 200 kl_loss_scale: 0.1 kl_loss_balancing: 0.8 +kl_loss_free_nats: 1.0 world_model_lr: 3e-4 # ActorCritic parameters diff --git a/config/config.yaml b/config/config.yaml index 94b1539..4cd2968 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,5 +1,6 @@ defaults: - agent/dreamer_v2 + #- env/dm_quadruped - env/dm_cartpole - _self_ @@ -7,13 +8,13 @@ seed: 42 device_type: cuda training: - steps: 1e6 + steps: 5e5 prefill: 1000 - pretrain: 0 + pretrain: 1e3 batch_size: 1024 gradient_steps_per_step: 5 save_checkpoint_every: 1e5 - val_logs_every: 5e3 + val_logs_every: 2.5e3 validation: rollout_num: 5 diff --git a/config/env/dm_cartpole.yaml b/config/env/dm_cartpole.yaml index 2fcabf9..bf2bae9 100644 --- a/config/env/dm_cartpole.yaml +++ b/config/env/dm_cartpole.yaml @@ -3,8 +3,7 @@ domain_name: cartpole task_name: swingup run_on_pixels: true obs_res: [64, 64] -repeat_action_num: 5 +camera_id: -1 +repeat_action_num: 2 transforms: - _target_: rl_sandbox.utils.env.ActionNormalizer - - _target_: rl_sandbox.utils.env.ActionDisritezer - actions_num: 10 diff --git a/main.py b/main.py index 16b963d..b973d31 100644 --- a/main.py +++ b/main.py @@ -19,11 +19,23 @@ from rl_sandbox.utils.schedulers import LinearScheduler +class SummaryWriterMock(): + def add_scalar(*args, **kwargs): + pass + + def add_video(*args, **kwargs): + pass + + def add_image(*args, **kwargs): + pass + + @hydra.main(version_base="1.2", config_path='config', config_name='config') def main(cfg: DictConfig): # print(OmegaConf.to_yaml(cfg)) torch.distributions.Distribution.set_default_validate_args(False) torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True env: Env = hydra.utils.instantiate(cfg.env) @@ -37,17 +49,20 @@ def main(cfg: DictConfig): # TODO: Implement smarter techniques for exploration # (Plan2Explore, etc) - policy_agent = hydra.utils.instantiate(cfg.agent, + agent = hydra.utils.instantiate(cfg.agent, obs_space_num=env.observation_space.shape[0], # FIXME: feels bad - actions_num=(env.action_space.high - env.action_space.low + 1).item(), + # actions_num=(env.action_space.high - env.action_space.low + 1).item(), + # FIXME: currently only continuous tasks + actions_num=env.action_space.shape[0], + action_type='continuous', device_type=cfg.device_type) - agent = ExplorativeAgent( - policy_agent, - # TODO: For dreamer, add noise for sampling instead - # of just random actions - RandomAgent(env), - LinearScheduler(0.9, 0.01, 5_000)) + # agent = ExplorativeAgent( + # policy_agent, + # # TODO: For dreamer, add noise for sampling instead + # # of just random actions + # 
RandomAgent(env), + # LinearScheduler(0.9, 0.01, 5_000)) writer = SummaryWriter() prof = profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], @@ -55,13 +70,30 @@ def main(cfg: DictConfig): schedule=torch.profiler.schedule(wait=10, warmup=10, active=5, repeat=5), with_stack=True) if cfg.debug.profiler else None - for i in tqdm(range(cfg.training.pretrain), desc='Pretraining'): + for i in tqdm(range(int(cfg.training.pretrain)), desc='Pretraining'): s, a, r, n, f = buff.sample(cfg.training.batch_size, cluster_size=cfg.agent.get('batch_cluster_size', 1)) losses = agent.train(s, a, r, n, f) for loss_name, loss in losses.items(): writer.add_scalar(f'pre_train/{loss_name}', loss, i) + # if cfg.training.pretrain > 0: + if i % 2e2 == 0: + rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) + # TODO: make logs visualization in separate process + metrics = metrics_evaluator.calculate_metrics(rollouts) + for metric_name, metric in metrics.items(): + writer.add_scalar(f'val/{metric_name}', metric, -10 + i/100) + + if cfg.validation.visualize: + rollouts = collect_rollout_num(env, 1, agent, collect_obs=True) + + for rollout in rollouts: + video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) + writer.add_video('val/visualization', video, -10 + i/100) + # FIXME: Very bad from architecture point + agent.viz_log(rollout, writer, -10 + i/100) + global_step = 0 pbar = tqdm(total=cfg.training.steps, desc='Training') while global_step < cfg.training.steps: @@ -87,7 +119,8 @@ def main(cfg: DictConfig): ### Validation if global_step % cfg.training.val_logs_every == 0: - rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) + with torch.no_grad(): + rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) # TODO: make logs visualization in separate process metrics = metrics_evaluator.calculate_metrics(rollouts) for metric_name, metric in metrics.items(): @@ -100,7 +133,8 @@ def main(cfg: DictConfig): video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) writer.add_video('val/visualization', video, global_step) # FIXME: Very bad from architecture point - agent.policy_ag.viz_log(rollout, writer, global_step) + with torch.no_grad(): + agent.viz_log(rollout, writer, global_step) ### Checkpoint if global_step % cfg.training.save_checkpoint_every == 0: diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 9ff1f44..3e9447b 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -26,15 +26,63 @@ def forward(self, x): return x.view(*self.shape) -class Quantize(nn.Module): +class Sigmoid2(nn.Module): + def forward(self, x): + return 2*torch.sigmoid(x/2) + +class NormalWithOffset(nn.Module): + def __init__(self, min_std: float, std_trans: str = 'sigmoid2', transform: t.Optional[str] = None): + super().__init__() + self.min_std = min_std + match std_trans: + case 'softplus': + self.std_trans = nn.Softplus() + case 'sigmoid': + self.std_trans = nn.Sigmoid() + case 'sigmoid2': + self.std_trans = Sigmoid2() + case _: + raise RuntimeError("Unknown std transformation") + + match transform: + case 'tanh': + self.trans = [td.TanhTransform(cache_size=1)] + case None: + self.trans = None + case _: + raise RuntimeError("Unknown distribution transformation") + + def forward(self, x): + mean, std = x.chunk(2, dim=-1) + dist = td.Normal(mean, self.std_trans(std) + self.min_std) + if self.trans is None: + return dist + else: + return td.TransformedDistribution(dist, self.trans) - def 
forward(self, logits): - return logits - # return torch.distributions.one_hot_categorical.OneHotCategoricalStraightThrough( - # logits=logits) + +class DistLayer(nn.Module): + def __init__(self, type: str): + super().__init__() + match type: + case 'mse': + self.dist = lambda x: td.Normal(x, 1.0) + case 'normal': + self.dist = NormalWithOffset(min_std=0.1) + case 'onehot': + self.dist = lambda x: td.OneHotCategoricalStraightThrough(logits=x) + case 'normal_tanh': + self.dist = NormalWithOffset(min_std=0.1, transform='tanh') + case 'binary': + self.dist = lambda x: td.Bernoulli(logits=x) + case _: + raise RuntimeError("Invalid dist layer") + + def forward(self, x): + return td.Independent(self.dist(x), 1) def Dist(val): - return td.OneHotCategoricalStraightThrough(logits=val) + return DistLayer('onehot')(val) class RSSM(nn.Module): @@ -87,8 +135,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' - View((-1, latent_dim, self.latent_classes)), - Quantize()) for _ in range(self.ensemble_num) + View((-1, latent_dim, self.latent_classes))) for _ in range(self.ensemble_num) ]) # For observation we do not have ensemble @@ -100,10 +147,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'obs_dist' - View((-1, latent_dim, self.latent_classes)), - # NOTE: Maybe worth having some LogSoftMax as activation - # before using input as logits for distribution - Quantize()) + View((-1, latent_dim, self.latent_classes))) def estimate_stochastic_latent(self, prev_determ): dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] @@ -111,7 +155,7 @@ def estimate_stochastic_latent(self, prev_determ): # taking only one random between all ensembles # FIXME: temporary use the same model idx = torch.randint(0, self.ensemble_num, ()) - return dists_per_model[0] + return dists_per_model[idx] def predict_next(self, stoch_latent, @@ -144,7 +188,7 @@ def forward(self, h_prev: t.Optional[tuple[torch.Tensor, torch.Tensor]], embed, # Move outside of rssm to omit checking if h_prev is None: h_prev = (torch.zeros(( - *action.shape[:-1], + *embed.shape[:-1], self.hidden_size, ), device=next(self.stoch_net.parameters()).device), @@ -172,6 +216,7 @@ def __init__(self, kernel_sizes=[4, 4, 4, 4]): for i, k in enumerate(kernel_sizes): out_channels = 2**i * channel_step layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2)) + layers.append(nn.BatchNorm2d(out_channels)) layers.append(nn.ELU(inplace=True)) in_channels = out_channels layers.append(nn.Flatten()) @@ -197,6 +242,7 @@ def __init__(self, input_size, kernel_sizes=[5, 5, 6, 6]): out_channels = 3 layers.append(nn.ConvTranspose2d(in_channels, 3, kernel_size=k, stride=2)) else: + layers.append(nn.BatchNorm2d(in_channels)) layers.append( nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=2)) @@ -208,13 +254,15 @@ def forward(self, X): x = self.convin(X) x = x.view(-1, 32 * self.channel_step, 1, 1) return self.net(x) + # return td.Independent(td.Normal(self.net(x), 1.0), 3) class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, - actions_num, kl_loss_scale, kl_loss_balancing): + actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats): super().__init__() + self.kl_free_nats = kl_free_nats self.kl_beta = kl_loss_scale 
self.rssm_dim = rssm_dim self.latent_dim = latent_dim @@ -234,13 +282,14 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, 1, hidden_size=400, num_layers=4, - intermediate_activation=nn.ELU) + intermediate_activation=nn.ELU, + final_activation=DistLayer('mse')) self.discount_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, hidden_size=400, num_layers=4, intermediate_activation=nn.ELU, - final_activation=nn.Sigmoid) + final_activation=DistLayer('binary')) def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor]): determ_state, next_repr_logits = self.recurrent_model.predict_next( @@ -249,9 +298,8 @@ def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor next_repr = Dist(next_repr_logits).rsample().reshape( -1, self.latent_dim * self.latent_classes) reward = self.reward_predictor( - torch.concat([determ_state.squeeze(0), next_repr], dim=1)) - is_finished = self.discount_predictor( - torch.concat([determ_state.squeeze(0), next_repr], dim=1)) + torch.concat([determ_state.squeeze(0), next_repr], dim=1)).rsample() + is_finished = self.discount_predictor(torch.concat([determ_state.squeeze(0), next_repr], dim=1)).sample() return determ_state, next_repr, reward, is_finished def get_latent(self, obs: torch.Tensor, action, state): @@ -265,51 +313,51 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) - embed_c = embed.view(b // self.cluster_size, self.cluster_size, -1) + embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) - obs_c = obs.view(-1, self.cluster_size, 3, h, w) - a_c = a.view(-1, self.cluster_size, self.actions_num) - r_c = r.view(-1, self.cluster_size, 1) - f_c = is_finished.view(-1, self.cluster_size, 1) + obs_c = obs.reshape(-1, self.cluster_size, 3, h, w) + a_c = a.reshape(-1, self.cluster_size, self.actions_num) + r_c = r.reshape(-1, self.cluster_size, 1) + f_c = is_finished.reshape(-1, self.cluster_size, 1) h_prev = None losses = defaultdict(lambda: torch.zeros(1).to(next(self.parameters()).device)) def KL(dist1, dist2, clusterify: bool = False): KL_ = torch.distributions.kl_divergence - one = torch.zeros(1,device=next(self.parameters()).device) + one = self.kl_free_nats * torch.ones(1, device=next(self.parameters()).device) if clusterify: kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)), one).view(-1, self.cluster_size).mean(dim=0).sum() kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())), one).view(-1, self.cluster_size).mean(dim=0).sum() else: - kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)), one) - kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())), one) - return self.kl_beta * (self.alpha * kl_lhs.mean()+ (1 - self.alpha) * kl_rhs.mean()) + kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)), one).mean() + kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())), one).mean() + return self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs) latent_vars = [] determ_vars = [] prior_logits = [] posterior_logits = [] - inps = [] - x_recovered = [] + + # inps = [] + # reconstructed = [] for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 x_t, embed_t, a_t, r_t, f_t = obs_c[:, t], embed_c[:, t].unsqueeze( - 0), a_c[:, t].unsqueeze(0), r_c[:, t], f_c[:, t] + 0), a_c[:, t].unsqueeze(0), r_c[:, t], f_c[:, t] determ_t, prior_stoch_logits, posterior_stoch_logits = self.recurrent_model.forward( h_prev, 
embed_t, a_t) - posterior_stoch_dist = Dist(posterior_stoch_logits) - posterior_stoch = posterior_stoch_dist.rsample().reshape( + posterior_stoch = Dist(posterior_stoch_logits).rsample().reshape( -1, self.latent_dim * self.latent_classes) - # x_r = self.image_predictor(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) - # inps.append(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) - # x_recovered.append(x_r) - - # losses['loss_reconstruction'] += nn.functional.mse_loss(x_t, x_r) - # losses['loss_kl_reg'] += KL(prior_stoch_logits, posterior_stoch_logits) + # inp_t = torch.concat([determ_t.squeeze(0), posterior_stoch], dim=-1) + # x_r = self.image_predictor(inp_t) + # inps.append(inp_t) + # reconstructed.append(x_r) + #losses['loss_reconstruction'] += nn.functional.mse_loss(x_t, x_r) + #losses['loss_kl_reg'] += KL(prior_stoch_logits, posterior_stoch_logits) h_prev = [determ_t, posterior_stoch.unsqueeze(0)] determ_vars.append(determ_t.squeeze(0)) @@ -318,35 +366,23 @@ def KL(dist1, dist2, clusterify: bool = False): prior_logits.append(prior_stoch_logits) posterior_logits.append(posterior_stoch_logits) - # r_t_pred = self.reward_predictor( - # torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1).detach()) - # f_t_pred = self.discount_predictor( - # torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) - - # x_r = self.image_predictor(torch.concat([determ_t.squeeze(0), posterior_stoch], dim=1)) - - # losses['loss_reconstruction'] += nn.functional.mse_loss(x_t, x_r) - # losses['loss_reward_pred'] += F.mse_loss(r_t, r_t_pred) - # losses['loss_discount_pred'] += F.cross_entropy(f_t.type(torch.float32), - # f_t_pred) - # # NOTE: entropy can be added as metric - # losses['loss_kl_reg'] += KL(prior_stoch_logits, posterior_stoch_logits) - - # inp = torch.concat([determ_vars.squeeze(0), posterior_stoch], dim=1) - inp = torch.concat([torch.concat(determ_vars), torch.concat(latent_vars)], dim=1) - r_pred = self.reward_predictor(inp).squeeze() - f_pred = self.discount_predictor(inp).squeeze() - x_r = self.image_predictor(inp) - - losses['loss_reconstruction'] += F.mse_loss(obs, x_r, reduction='none').view(-1, self.cluster_size, 3, h, w).mean(dim=(0, 2, 3, 4)).sum() - losses['loss_reward_pred'] += F.mse_loss(r, r_pred, reduction='none').mean(dim=0).sum() - losses['loss_discount_pred'] += F.cross_entropy(is_finished.type(torch.float32), - f_pred, reduction='none').mean(dim=0).sum() + inp = torch.concat([torch.stack(determ_vars, dim=1), torch.stack(latent_vars, dim=1)], dim=-1) + r_pred = self.reward_predictor(inp) + f_pred = self.discount_predictor(inp) + x_r = self.image_predictor(torch.flatten(inp, 0, 1)) + + losses['loss_reconstruction'] += F.mse_loss(x_r, obs) * 32 + #losses['loss_reward_pred'] += F.mse_loss(r, r_pred) + #losses['loss_discount_pred'] += F.binary_cross_entropy_with_logits(is_finished.type(torch.float32), f_pred) + # losses['loss_reconstruction'] += -x_r.log_prob(obs).mean() + losses['loss_reward_pred'] += -r_pred.log_prob(r_c).mean() + losses['loss_discount_pred'] += -f_pred.log_prob(f_c.type(torch.float32)).mean() # NOTE: entropy can be added as metric - losses['loss_kl_reg'] += KL(torch.concat(prior_logits), torch.concat(posterior_logits), True) + losses['loss_kl_reg'] += KL(torch.flatten(torch.stack(prior_logits, dim=1), 0, 1), + torch.flatten(torch.stack(posterior_logits, dim=1), 0, 1), False) - return losses, torch.stack(latent_vars).reshape(-1, self.latent_dim * self.latent_classes).detach() + return losses, torch.stack(latent_vars, 
dim=1).reshape(-1, self.latent_dim * self.latent_classes).detach() class ImaginativeCritic(nn.Module): @@ -401,6 +437,7 @@ def __init__( self, obs_space_num: int, # NOTE: encoder/decoder will work only with 64x64 currently actions_num: int, + action_type: str, batch_cluster_size: int, latent_dim: int, latent_classes: int, @@ -408,6 +445,7 @@ def __init__( discount_factor: float, kl_loss_scale: float, kl_loss_balancing: float, + kl_loss_free_nats: float, imagination_horizon: int, critic_update_interval: int, actor_reinforce_fraction: float, @@ -419,7 +457,6 @@ def __init__( critic_lr: float, device_type: str = 'cpu'): - self.actions_num = actions_num self.imagination_horizon = imagination_horizon self.cluster_size = batch_cluster_size self.actions_num = actions_num @@ -430,21 +467,22 @@ def __init__( self.world_model = WorldModel(batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, - kl_loss_balancing).to(device_type) - # TODO: final activation should depend whether agent - # action space in one hot or identity if real-valued + kl_loss_balancing, kl_loss_free_nats).to(device_type) + self.actor = fc_nn_generator(latent_dim * latent_classes, - actions_num, + actions_num * 2 if action_type == 'continuous' else actions_num, 400, 4, intermediate_activation=nn.ELU, - final_activation=Quantize).to(device_type) + final_activation=DistLayer('normal_tanh' if action_type == 'continuous' else 'onehot')).to(device_type) + self.critic = ImaginativeCritic(discount_factor, critic_update_interval, critic_soft_update_fraction, critic_value_target_lambda, latent_dim * latent_classes, actions_num).to(device_type) + self.scaler = torch.cuda.amp.GradScaler() self.world_model_optimizer = torch.optim.AdamW(self.world_model.parameters(), lr=world_model_lr, eps=1e-5, @@ -462,12 +500,12 @@ def __init__( def imagine_trajectory( self, z_0 ) -> tuple[torch.Tensor, torch.distributions.Distribution, torch.Tensor, torch.Tensor, - torch.Tensor]: + torch.Tensor, torch.Tensor]: world_state = None zs, actions, next_zs, rewards, ts, determs = [], [], [], [], [], [] z = z_0.detach() for _ in range(self.imagination_horizon): - a = Dist(self.actor(z)) + a = self.actor(z) world_state, next_z, reward, is_finished = self.world_model.predict_next( z, a.rsample(), world_state) @@ -497,6 +535,7 @@ def preprocess_obs(obs: torch.Tensor): # Swap channel from last to 3 from last order = order[:-3] + [order[-1]] + order[-3:-1] return ((obs.type(torch.float32) / 255.0) - 0.5).permute(order) + # return obs.type(torch.float32).permute(order) def get_action(self, obs: Observation) -> Action: # NOTE: pytorch fails without .copy() only when get_action is called @@ -509,26 +548,31 @@ def get_action(self, obs: Observation) -> Action: self._state = (determ, latent_repr_dist.rsample().reshape(-1, 32 * 32).unsqueeze(0)) - actor_dist = Dist(self.actor(self._state[1])) + actor_dist = self.actor(self._state[1]) self._last_action = actor_dist.rsample() - self._action_probs += actor_dist.probs.squeeze() - self._latent_probs += latent_repr_dist.probs.squeeze() + if False: + self._action_probs += actor_dist.base_dist.probs.squeeze() + self._latent_probs += latent_repr_dist.base_dist.probs.squeeze() self._stored_steps += 1 - return np.array([self._last_action.squeeze().detach().cpu().numpy().argmax()]) + if False: + return np.array([self._last_action.squeeze().detach().cpu().numpy().argmax()]) + else: + return self._last_action.squeeze().detach().cpu().numpy() def _generate_video(self, obs: Observation, init_action: Action): 
obs = torch.from_numpy(obs.copy()).to(next(self.world_model.parameters()).device) obs = self.preprocess_obs(obs).unsqueeze(0) - action = F.one_hot(self.from_np(init_action).to(torch.int64), - num_classes=self.actions_num).squeeze() - z_0 = Dist(self.world_model.get_latent(obs, - action.unsqueeze(0).unsqueeze(0), - None)[1]).rsample().reshape(-1, - 32 * 32).unsqueeze(0) + if False: + action = F.one_hot(self.from_np(init_action).to(torch.int64), + num_classes=self.actions_num).squeeze() + else: + action = self.from_np(init_action) + z_0 = Dist(self.world_model.get_latent(obs, action.unsqueeze(0).unsqueeze(0), None)[1]).rsample().reshape(-1, 32 * 32).unsqueeze(0) zs, _, _, _, _, determs = self.imagine_trajectory(z_0.squeeze(0)) + # video_r = self.world_model.image_predictor(torch.concat([determs, zs], dim=2)).rsample().cpu().detach().numpy() video_r = self.world_model.image_predictor(torch.concat([determs, zs], dim=2)).cpu().detach().numpy() video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) return video_r @@ -549,13 +593,17 @@ def viz_log(self, rollout, logger, epoch_num): videos_comparison = np.expand_dims(np.concatenate([videos, videos_r], axis=2), 0) latent_hist = (self._latent_probs / self._stored_steps).detach().cpu().numpy() latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) - action_hist = (self._action_probs / self._stored_steps).detach().cpu().numpy() - # logger.add_histogram('val/action_probs', action_hist, epoch_num) - fig = plt.Figure() - ax = fig.add_axes([0, 0, 1, 1]) - ax.bar(np.arange(self.actions_num), action_hist) - logger.add_figure('val/action_probs', fig, epoch_num) + # if discrete action space + if False: + action_hist = (self._action_probs / self._stored_steps).detach().cpu().numpy() + fig = plt.Figure() + ax = fig.add_axes([0, 0, 1, 1]) + ax.bar(np.arange(self.actions_num), action_hist) + logger.add_figure('val/action_probs', fig, epoch_num) + else: + # log mean +- std + pass logger.add_image('val/latent_probs', latent_hist, epoch_num, dataformats='HW') logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=1), epoch_num, dataformats='HW') logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) @@ -569,52 +617,73 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation obs = self.preprocess_obs(self.from_np(obs)) a = self.from_np(a).to(torch.int64) - a = F.one_hot(a, num_classes=self.actions_num).squeeze() + if False: + a = F.one_hot(a, num_classes=self.actions_num).squeeze() r = self.from_np(r) next_obs = self.preprocess_obs(self.from_np(next_obs)) is_finished = self.from_np(is_finished) # take some latent embeddings as initial step - losses, discovered_latents = self.world_model.calculate_loss( - next_obs, a, r, is_finished) - - # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device - world_model_loss = torch.Tensor(1).to(next(self.world_model.parameters()).device) - for l in losses.values(): - world_model_loss += l - - self.world_model_optimizer.zero_grad() - world_model_loss.backward() - self.world_model_optimizer.step() - - # idx = torch.randperm(discovered_latents.size(0)) - # initial_states = discovered_latents[idx] - - # losses_ac = defaultdict( - # lambda: torch.zeros(1).to(next(self.critic.parameters()).device)) - - # zs, action_dists, next_zs, rewards, terminal_flags, _ = self.imagine_trajectory( - # initial_states) - # vs = self.critic.lambda_return(next_zs, rewards, terminal_flags) - - # losses_ac['loss_critic'] = 
F.mse_loss(self.critic.estimate_value(next_zs.detach()), - # vs.detach(), - # reduction='sum') / zs.shape[1] - # losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control - # losses_ac['loss_actor_dynamics_backprop'] = -( - # (1 - self.rho) * vs).sum() / zs.shape[1] - # losses_ac['loss_actor_entropy'] = -self.eta * torch.stack( - # [a.entropy() for a in action_dists[:-1]]).sum() / zs.shape[1] - # losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ - # 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] - - # self.actor_optimizer.zero_grad() - # self.critic_optimizer.zero_grad() - # losses_ac['loss_critic'].backward() - # losses_ac['loss_actor'].backward() - # self.actor_optimizer.step() - # self.critic_optimizer.step() + with torch.cuda.amp.autocast(enabled=False): + losses, discovered_latents = self.world_model.calculate_loss( + next_obs, a, r, is_finished) + + # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device + # world_model_loss = torch.Tensor(1).to(next(self.world_model.parameters()).device) + world_model_loss = (losses['loss_reconstruction'] + + losses['loss_reward_pred'] + + losses['loss_kl_reg'] + + losses['loss_discount_pred']) + # for l in losses.values(): + # world_model_loss += l + + self.world_model_optimizer.zero_grad(set_to_none=True) + self.scaler.scale(world_model_loss).backward() + + # FIXME: clip gradient should be parametrized + self.scaler.unscale_(self.world_model_optimizer) + nn.utils.clip_grad_norm_(self.world_model.parameters(), 100) + self.scaler.step(self.world_model_optimizer) + + # with torch.cuda.amp.autocast(enabled=False): + # idx = torch.randperm(discovered_latents.size(0)) + # initial_states = discovered_latents[idx] + + # losses_ac = defaultdict( + # lambda: torch.zeros(1).to(next(self.critic.parameters()).device)) + + # zs, action_dists, next_zs, rewards, terminal_flags, _ = self.imagine_trajectory( + # initial_states) + # vs = self.critic.lambda_return(next_zs, rewards, terminal_flags) + + # losses_ac['loss_critic'] = F.mse_loss(self.critic.estimate_value(next_zs.detach()), + # vs.detach(), + # reduction='sum') / zs.shape[1] + # losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control + # losses_ac['loss_actor_dynamics_backprop'] = -( + # (1 - self.rho) * vs).sum() / zs.shape[1] + # # FIXME: Is it correct to use normal entropy with Tanh transformation + # losses_ac['loss_actor_entropy'] = -self.eta * torch.stack( + # [a.base_dist.base_dist.entropy() for a in action_dists[:-1]]).sum() / zs.shape[1] + # losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ + # 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] + + # self.actor_optimizer.zero_grad(set_to_none=True) + # self.critic_optimizer.zero_grad(set_to_none=True) + + # self.scaler.scale(losses_ac['loss_critic']).backward() + # self.scaler.scale(losses_ac['loss_actor']).backward() + + # self.scaler.unscale_(self.actor_optimizer) + # self.scaler.unscale_(self.critic_optimizer) + # nn.utils.clip_grad_norm_(self.actor.parameters(), 100) + # nn.utils.clip_grad_norm_(self.critic.parameters(), 100) + + # self.scaler.step(self.actor_optimizer) + # self.scaler.step(self.critic_optimizer) + # self.critic.update_target() + self.scaler.update() losses = {l: val.detach().cpu().item() for l, val in losses.items()} # losses_ac = {l: val.detach().cpu().item() for l, val in losses_ac.items()} diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py index adbcd1e..6cc05c0 100644 --- 
a/rl_sandbox/utils/env.py +++ b/rl_sandbox/utils/env.py @@ -186,14 +186,16 @@ def _action_space(self): class DmEnv(Env): - def __init__(self, run_on_pixels: bool, obs_res: tuple[int, - int], repeat_action_num: int, + def __init__(self, run_on_pixels: bool, + camera_id: int, + obs_res: tuple[int, int], repeat_action_num: int, domain_name: str, task_name: str, transforms: list[ActionTransformer]): + self.camera_id = camera_id self.env: dmEnviron = suite.load(domain_name=domain_name, task_name=task_name) super().__init__(run_on_pixels, obs_res, repeat_action_num, transforms) def render(self): - return self.env.physics.render(*self.obs_res) + return self.env.physics.render(*self.obs_res, camera_id=self.camera_id) def _uncode_ts(self, ts: TimeStep) -> EnvStepResult: if self.run_on_pixels: @@ -214,7 +216,7 @@ def _step(self, action: Action, repeat_num: int) -> EnvStepResult: env_res = self._uncode_ts(self.env.step(action)) else: env_res = ts - env_res.reward = np.tanh(rew + (env_res.reward or 0.0)) + env_res.reward = rew + (env_res.reward or 0.0) return env_res def reset(self) -> EnvStepResult: diff --git a/rl_sandbox/utils/fc_nn.py b/rl_sandbox/utils/fc_nn.py index f00edb8..e704556 100644 --- a/rl_sandbox/utils/fc_nn.py +++ b/rl_sandbox/utils/fc_nn.py @@ -6,7 +6,7 @@ def fc_nn_generator(input_num: int, hidden_size: int, num_layers: int, intermediate_activation: t.Type[nn.Module] = nn.ReLU, - final_activation: t.Type[nn.Module] = nn.Identity): + final_activation: nn.Module = nn.Identity()): layers = [] layers.append(nn.Linear(input_num, hidden_size)) layers.append(nn.ReLU(inplace=True)) @@ -14,5 +14,5 @@ def fc_nn_generator(input_num: int, layers.append(nn.Linear(hidden_size, hidden_size)) layers.append(intermediate_activation(inplace=True)) layers.append(nn.Linear(hidden_size, output_num)) - layers.append(final_activation()) + layers.append(final_activation) return nn.Sequential(*layers) diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index 39b8c4b..41970d0 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -69,7 +69,7 @@ def add_sample(self, s: State, a: Action, r: float, n: State, f: bool): if f: self.add_rollout( Rollout(np.array(self.curr_rollout.states), - np.array(self.curr_rollout.actions), + np.array(self.curr_rollout.actions).reshape(len(self.curr_rollout.actions), -1), np.array(self.curr_rollout.rewards, dtype=np.float32), np.array([n]), np.array(self.curr_rollout.is_finished))) self.curr_rollout = None @@ -105,6 +105,10 @@ def sample( if r_idx == len(self.rollouts): r_len += 1 + # FIXME: hot-fix for 1d action space, better to find smarter solution + actions = np.array(rollout.actions[s_idx:s_idx + cluster_size]).reshape(cluster_size, -1) + else: + actions = rollout.actions[s_idx:s_idx + cluster_size] s.append(rollout.states[s_idx:s_idx + cluster_size]) a.append(rollout.actions[s_idx:s_idx + cluster_size]) From faf35b9e6b4be039e48332ccdbdcb411ce55d171 Mon Sep 17 00:00:00 2001 From: Midren Date: Fri, 13 Jan 2023 10:28:57 +0000 Subject: [PATCH 025/106] Fix training for reconstruction - log_prob used instead of mse_loss - set number of ensembles to 1 - remove normalization layers which are absent in Dreamer Improved video logging: - to only imagine track added imagine+observe viz - same actions are applied accross replay rollout and imagined rollout --- config/agent/dreamer_v2.yaml | 2 +- config/config.yaml | 2 +- config/env/mock.yaml | 2 +- main.py | 12 ++-- rl_sandbox/agents/dreamer_v2.py | 111 
+++++++++++++++++++----------- rl_sandbox/utils/replay_buffer.py | 2 +- 6 files changed, 81 insertions(+), 50 deletions(-) diff --git a/config/agent/dreamer_v2.yaml b/config/agent/dreamer_v2.yaml index 961aea0..ee50a82 100644 --- a/config/agent/dreamer_v2.yaml +++ b/config/agent/dreamer_v2.yaml @@ -4,7 +4,7 @@ batch_cluster_size: 32 latent_dim: 32 latent_classes: 32 rssm_dim: 200 -kl_loss_scale: 0.1 +kl_loss_scale: 1.0 kl_loss_balancing: 0.8 kl_loss_free_nats: 1.0 world_model_lr: 3e-4 diff --git a/config/config.yaml b/config/config.yaml index 4cd2968..9ad962c 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -10,7 +10,7 @@ device_type: cuda training: steps: 5e5 prefill: 1000 - pretrain: 1e3 + pretrain: 100 batch_size: 1024 gradient_steps_per_step: 5 save_checkpoint_every: 1e5 diff --git a/config/env/mock.yaml b/config/env/mock.yaml index ec28cf3..c62296c 100644 --- a/config/env/mock.yaml +++ b/config/env/mock.yaml @@ -1,5 +1,5 @@ _target_: rl_sandbox.utils.env.MockEnv run_on_pixels: true obs_res: [64, 64] -repeat_action_num: 1 +repeat_action_num: 5 transforms: [] diff --git a/main.py b/main.py index b973d31..c4f2166 100644 --- a/main.py +++ b/main.py @@ -77,22 +77,23 @@ def main(cfg: DictConfig): for loss_name, loss in losses.items(): writer.add_scalar(f'pre_train/{loss_name}', loss, i) - # if cfg.training.pretrain > 0: - if i % 2e2 == 0: + log_every_n = 25 + st = int(cfg.training.pretrain) // log_every_n + if i % log_every_n == 0: rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) # TODO: make logs visualization in separate process metrics = metrics_evaluator.calculate_metrics(rollouts) for metric_name, metric in metrics.items(): - writer.add_scalar(f'val/{metric_name}', metric, -10 + i/100) + writer.add_scalar(f'val/{metric_name}', metric, -st + i/log_every_n) if cfg.validation.visualize: rollouts = collect_rollout_num(env, 1, agent, collect_obs=True) for rollout in rollouts: video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) - writer.add_video('val/visualization', video, -10 + i/100) + writer.add_video('val/visualization', video, -st + i/log_every_n) # FIXME: Very bad from architecture point - agent.viz_log(rollout, writer, -10 + i/100) + agent.viz_log(rollout, writer, -st + i/log_every_n) global_step = 0 pbar = tqdm(total=cfg.training.steps, desc='Training') @@ -117,6 +118,7 @@ def main(cfg: DictConfig): global_step += cfg.env.repeat_action_num pbar.update(cfg.env.repeat_action_num) + # FIXME: Currently works only val_logs_every is multiplier of amount of steps per rollout ### Validation if global_step % cfg.training.val_logs_every == 0: with torch.no_grad(): diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 3e9447b..1c6fac7 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -112,14 +112,14 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): super().__init__() self.latent_dim = latent_dim self.latent_classes = latent_classes - self.ensemble_num = 5 + self.ensemble_num = 1 self.hidden_size = hidden_size # Calculate deterministic state from prev stochastic, prev action and prev deterministic self.pre_determ_recurrent = nn.Sequential( nn.Linear(latent_dim * latent_classes + actions_num, hidden_size), # Dreamer 'img_in' - nn.LayerNorm(hidden_size), + # nn.LayerNorm(hidden_size), nn.ELU(inplace=True) ) self.determ_recurrent = nn.GRU(input_size=hidden_size, @@ -127,11 +127,10 @@ def __init__(self, latent_dim, hidden_size, actions_num, 
latent_classes): # Calculate stochastic state from prior embed # shared between all ensemble models - # FIXME: check whether it is trully correct self.ensemble_prior_estimator = nn.ModuleList([ nn.Sequential( nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' - nn.LayerNorm(hidden_size), + # nn.LayerNorm(hidden_size), nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' @@ -143,7 +142,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): img_sz = 4 * 384 # 384*2x2 self.stoch_net = nn.Sequential( nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' - nn.LayerNorm(hidden_size), + # nn.LayerNorm(hidden_size), nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'obs_dist' @@ -153,7 +152,7 @@ def estimate_stochastic_latent(self, prev_determ): dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] # NOTE: Maybe something smarter can be used instead of # taking only one random between all ensembles - # FIXME: temporary use the same model + # NOTE: in Dreamer ensemble_num is always 1 idx = torch.randint(0, self.ensemble_num, ()) return dists_per_model[idx] @@ -165,15 +164,16 @@ def predict_next(self, if deter_state is None: deter_state = torch.zeros(*stoch_latent.shape[:2], self.hidden_size).to( next(self.stoch_net.parameters()).device) - x = self.pre_determ_recurrent(torch.concat([stoch_latent, action], dim=2)) - _, determ = self.determ_recurrent(x, deter_state) + x = self.pre_determ_recurrent(torch.concat([stoch_latent, action], dim=-1)) + # NOTE: x and determ are actually the same value if sequence of 1 is inserted + x, determ = self.determ_recurrent(x, deter_state) # used for KL divergence - predicted_stoch_logits = self.estimate_stochastic_latent(determ) + predicted_stoch_logits = self.estimate_stochastic_latent(x) return determ, predicted_stoch_logits def update_current(self, determ, embed): # Dreamer 'obs_out' - return self.stoch_net(torch.concat([determ, embed], dim=2)) + return self.stoch_net(torch.concat([determ, embed], dim=-1)) def forward(self, h_prev: t.Optional[tuple[torch.Tensor, torch.Tensor]], embed, action): @@ -216,7 +216,7 @@ def __init__(self, kernel_sizes=[4, 4, 4, 4]): for i, k in enumerate(kernel_sizes): out_channels = 2**i * channel_step layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2)) - layers.append(nn.BatchNorm2d(out_channels)) + # layers.append(nn.BatchNorm2d(out_channels)) layers.append(nn.ELU(inplace=True)) in_channels = out_channels layers.append(nn.Flatten()) @@ -242,7 +242,7 @@ def __init__(self, input_size, kernel_sizes=[5, 5, 6, 6]): out_channels = 3 layers.append(nn.ConvTranspose2d(in_channels, 3, kernel_size=k, stride=2)) else: - layers.append(nn.BatchNorm2d(in_channels)) + # layers.append(nn.BatchNorm2d(in_channels)) layers.append( nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=2)) @@ -253,8 +253,7 @@ def __init__(self, input_size, kernel_sizes=[5, 5, 6, 6]): def forward(self, X): x = self.convin(X) x = x.view(-1, 32 * self.channel_step, 1, 1) - return self.net(x) - # return td.Independent(td.Normal(self.net(x), 1.0), 3) + return td.Independent(td.Normal(self.net(x), 1.0), 3) class WorldModel(nn.Module): @@ -323,15 +322,12 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, h_prev = None losses = defaultdict(lambda: torch.zeros(1).to(next(self.parameters()).device)) - def KL(dist1, dist2, clusterify: bool = False): + 
def KL(dist1, dist2): KL_ = torch.distributions.kl_divergence one = self.kl_free_nats * torch.ones(1, device=next(self.parameters()).device) - if clusterify: - kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)), one).view(-1, self.cluster_size).mean(dim=0).sum() - kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())), one).view(-1, self.cluster_size).mean(dim=0).sum() - else: - kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)), one).mean() - kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())), one).mean() + # TODO: kl_free_avg is used always + kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)).mean(), one) + kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())).mean(), one) return self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs) latent_vars = [] @@ -372,15 +368,12 @@ def KL(dist1, dist2, clusterify: bool = False): f_pred = self.discount_predictor(inp) x_r = self.image_predictor(torch.flatten(inp, 0, 1)) - losses['loss_reconstruction'] += F.mse_loss(x_r, obs) * 32 - #losses['loss_reward_pred'] += F.mse_loss(r, r_pred) - #losses['loss_discount_pred'] += F.binary_cross_entropy_with_logits(is_finished.type(torch.float32), f_pred) - # losses['loss_reconstruction'] += -x_r.log_prob(obs).mean() + losses['loss_reconstruction'] += -x_r.log_prob(obs).mean() losses['loss_reward_pred'] += -r_pred.log_prob(r_c).mean() losses['loss_discount_pred'] += -f_pred.log_prob(f_c.type(torch.float32)).mean() # NOTE: entropy can be added as metric losses['loss_kl_reg'] += KL(torch.flatten(torch.stack(prior_logits, dim=1), 0, 1), - torch.flatten(torch.stack(posterior_logits, dim=1), 0, 1), False) + torch.flatten(torch.stack(posterior_logits, dim=1), 0, 1)) return losses, torch.stack(latent_vars, dim=1).reshape(-1, self.latent_dim * self.latent_classes).detach() @@ -498,16 +491,24 @@ def __init__( self.reset() def imagine_trajectory( - self, z_0 + self, z_0, precomp_actions: t.Optional[list[Action]] = None, horizon: t.Optional[int] = None ) -> tuple[torch.Tensor, torch.distributions.Distribution, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if horizon is None: + horizon = self.imagination_horizon world_state = None zs, actions, next_zs, rewards, ts, determs = [], [], [], [], [], [] z = z_0.detach() - for _ in range(self.imagination_horizon): - a = self.actor(z) + for i in range(horizon): + if precomp_actions is not None: + a = precomp_actions[i].unsqueeze(0) + else: + a = self.actor(z).rsample() world_state, next_z, reward, is_finished = self.world_model.predict_next( - z, a.rsample(), world_state) + z, a, world_state) + # FIXME: + # is_finished should be shifted, as they imply whether the following state + # will be valid, not whether the current state is valid. 
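The shift described in the comment above is what a later patch in this series (027) implements with a cumulative product plus a one-step shift. A minimal standalone sketch of that convention, assuming per-step discount predictions shaped [horizon, batch, 1]; the helper name is illustrative, not part of the repository:

import torch

def shift_discounts(discounts: torch.Tensor) -> torch.Tensor:
    # Zero out every step after the first predicted termination...
    running = torch.cumprod(discounts, dim=0)
    # ...and shift by one step: the discount predicted at step t gates
    # step t+1, while the first imagined step is always weighted 1.
    return torch.cat([torch.ones_like(running[:1]), running[:-1]], dim=0)

# e.g. weights for a 15-step imagined rollout with a batch of 16:
# shift_discounts(0.995 * torch.ones(15, 16, 1))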
zs.append(z) actions.append(a) @@ -522,6 +523,7 @@ def imagine_trajectory( def reset(self): self._state = None + # FIXME: instead of zero, it should be mode of distribution self._last_action = torch.zeros((1, 1, self.actions_num), device=next(self.world_model.parameters()).device) self._latent_probs = torch.zeros((32, 32), device=next(self.world_model.parameters()).device) @@ -561,36 +563,63 @@ def get_action(self, obs: Observation) -> Action: else: return self._last_action.squeeze().detach().cpu().numpy() - def _generate_video(self, obs: Observation, init_action: Action): + def _generate_video(self, obs: Observation, actions: list[Action]): obs = torch.from_numpy(obs.copy()).to(next(self.world_model.parameters()).device) obs = self.preprocess_obs(obs).unsqueeze(0) if False: - action = F.one_hot(self.from_np(init_action).to(torch.int64), + actions = F.one_hot(self.from_np(actions).to(torch.int64), num_classes=self.actions_num).squeeze() else: - action = self.from_np(init_action) - z_0 = Dist(self.world_model.get_latent(obs, action.unsqueeze(0).unsqueeze(0), None)[1]).rsample().reshape(-1, 32 * 32).unsqueeze(0) - zs, _, _, _, _, determs = self.imagine_trajectory(z_0.squeeze(0)) + actions = self.from_np(actions) + z_0 = Dist(self.world_model.get_latent(obs, actions[0].unsqueeze(0).unsqueeze(0), None)[1]).rsample().reshape(-1, 32 * 32).unsqueeze(0) + zs, _, _, _, _, determs = self.imagine_trajectory(z_0.squeeze(0), actions[1:], horizon=self.imagination_horizon - 1) # video_r = self.world_model.image_predictor(torch.concat([determs, zs], dim=2)).rsample().cpu().detach().numpy() - video_r = self.world_model.image_predictor(torch.concat([determs, zs], dim=2)).cpu().detach().numpy() + video_r = self.world_model.image_predictor(torch.concat([torch.concat([torch.zeros_like(determs[0]).unsqueeze(0), determs]), torch.concat([z_0, zs])], dim=2)).mode.cpu().detach().numpy() video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) return video_r + def _generate_video_with_update(self, obs: list[Observation], init_action: list[Action]): + obs = torch.from_numpy(obs.copy()).to(next(self.world_model.parameters()).device) + obs = self.preprocess_obs(obs) + + if False: + action = F.one_hot(self.from_np(init_action).to(torch.int64), + num_classes=self.actions_num).squeeze() + else: + action = self.from_np(init_action) + state = None + video = [] + for o, a in zip(obs, action): + determ, stoch_logits = self.world_model.get_latent(o.unsqueeze(0), a.unsqueeze(0).unsqueeze(0), state) + z_0 = Dist(stoch_logits).rsample().reshape(-1, 32 * 32).unsqueeze(0) + state = (determ, z_0) + # zs, _, _, _, _, determs = self.imagine_trajectory(z_0.squeeze(0), horizon=1) + # video_r = self.world_model.image_predictor(torch.concat([determs, zs], dim=2)).rsample().cpu().detach().numpy() + video_r = self.world_model.image_predictor(torch.concat([determ, z_0], dim=-1)).mode.cpu().detach().numpy() + video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) + video.append(video_r) + return np.concatenate(video) + def viz_log(self, rollout, logger, epoch_num): init_indeces = np.random.choice(len(rollout.states) - self.imagination_horizon, 3) videos_r = np.concatenate([ self._generate_video(obs_0.copy(), a_0) for obs_0, a_0 in zip( - rollout.states[init_indeces], rollout.actions[init_indeces]) - ], - axis=3) + rollout.next_states[init_indeces], [rollout.actions[idx:idx+ self.imagination_horizon] for idx in init_indeces]) + ], axis=3) + + videos_r_update = np.concatenate([ + self._generate_video_with_update(obs_0.copy(), a_0) for obs_0, a_0 in 
zip( + [rollout.next_states[idx:idx+ self.imagination_horizon] for idx in init_indeces], + [rollout.actions[idx:idx+ self.imagination_horizon] for idx in init_indeces]) + ], axis=3) videos = np.concatenate([ - rollout.states[init_idx:init_idx + self.imagination_horizon].transpose( + rollout.next_states[init_idx:init_idx + self.imagination_horizon].transpose( 0, 3, 1, 2) for init_idx in init_indeces ], axis=3) - videos_comparison = np.expand_dims(np.concatenate([videos, videos_r], axis=2), 0) + videos_comparison = np.expand_dims(np.concatenate([videos, videos_r_update, videos_r], axis=2), 0) latent_hist = (self._latent_probs / self._stored_steps).detach().cpu().numpy() latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index 41970d0..e0b432d 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -111,7 +111,7 @@ def sample( actions = rollout.actions[s_idx:s_idx + cluster_size] s.append(rollout.states[s_idx:s_idx + cluster_size]) - a.append(rollout.actions[s_idx:s_idx + cluster_size]) + a.append(actions) r.append(rollout.rewards[s_idx:s_idx + cluster_size]) t.append(rollout.is_finished[s_idx:s_idx + cluster_size]) if s_idx != r_len - cluster_size: From 224fd52370eaee54533a43361c80195c9eaedc00 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 14 Jan 2023 14:26:44 +0000 Subject: [PATCH 026/106] Fixed discount factor (invert is_finished), add lamda_return tests --- pyproject.toml | 17 ++- rl_sandbox/agents/dreamer_v2.py | 133 +++++++++--------- {config => rl_sandbox/config}/agent/dqn.yaml | 0 .../config}/agent/dreamer_v2.yaml | 0 {config => rl_sandbox/config}/config.yaml | 0 .../config}/env/dm_cartpole.yaml | 0 {config => rl_sandbox/config}/env/mock.yaml | 0 rl_sandbox/test/dreamer/test_critic.py | 62 ++++++++ .../test}/test_linear_scheduler.py | 0 .../test}/test_replay_buffer.py | 0 main.py => rl_sandbox/train.py | 0 11 files changed, 142 insertions(+), 70 deletions(-) rename {config => rl_sandbox/config}/agent/dqn.yaml (100%) rename {config => rl_sandbox/config}/agent/dreamer_v2.yaml (100%) rename {config => rl_sandbox/config}/config.yaml (100%) rename {config => rl_sandbox/config}/env/dm_cartpole.yaml (100%) rename {config => rl_sandbox/config}/env/mock.yaml (100%) create mode 100644 rl_sandbox/test/dreamer/test_critic.py rename {tests => rl_sandbox/test}/test_linear_scheduler.py (100%) rename {tests => rl_sandbox/test}/test_replay_buffer.py (100%) rename main.py => rl_sandbox/train.py (100%) diff --git a/pyproject.toml b/pyproject.toml index d16679d..10d0305 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,11 +7,10 @@ name = 'rl_sandbox' version = "0.1.0" description = 'Sandbox for my RL experiments' authors = ['Roman Milishchuk '] -packages = [{include = 'rl_sandbox'}] +packages = [{include = 'rl_sandbox', exclude = 'rl_sandbox.test*'}] -[tool.yapf] -based_on_style = "pep8" -column_limit = 90 +[tool.poetry.scripts] +my-script = ['rl_sandbox.train:main'] [tool.poetry.dependencies] python = "^3.10" @@ -28,3 +27,13 @@ unpackable = '^0.0.4' hydra-core = "^1.2.0" matplotlib = "^3.0.0" webdataset = "^0.2.20" + +[tool.yapf] +based_on_style = "pep8" +column_limit = 90 + +[tool.pytest.ini_options] +addopts = [ + "--import-mode=importlib", +] + diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 1c6fac7..08b072d 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -298,8 +298,8 @@ def 
predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor -1, self.latent_dim * self.latent_classes) reward = self.reward_predictor( torch.concat([determ_state.squeeze(0), next_repr], dim=1)).rsample() - is_finished = self.discount_predictor(torch.concat([determ_state.squeeze(0), next_repr], dim=1)).sample() - return determ_state, next_repr, reward, is_finished + discount_factors = self.discount_predictor(torch.concat([determ_state.squeeze(0), next_repr], dim=1)).sample() + return determ_state, next_repr, reward, discount_factors def get_latent(self, obs: torch.Tensor, action, state): embed = self.encoder(obs) @@ -308,7 +308,7 @@ def get_latent(self, obs: torch.Tensor, action, state): return determ, latent_repr_logits def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, - is_finished: torch.Tensor): + discount: torch.Tensor): b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) @@ -317,7 +317,7 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, obs_c = obs.reshape(-1, self.cluster_size, 3, h, w) a_c = a.reshape(-1, self.cluster_size, self.actions_num) r_c = r.reshape(-1, self.cluster_size, 1) - f_c = is_finished.reshape(-1, self.cluster_size, 1) + d_c = discount.reshape(-1, self.cluster_size, 1) h_prev = None losses = defaultdict(lambda: torch.zeros(1).to(next(self.parameters()).device)) @@ -341,7 +341,7 @@ def KL(dist1, dist2): for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 x_t, embed_t, a_t, r_t, f_t = obs_c[:, t], embed_c[:, t].unsqueeze( - 0), a_c[:, t].unsqueeze(0), r_c[:, t], f_c[:, t] + 0), a_c[:, t].unsqueeze(0), r_c[:, t], d_c[:, t] determ_t, prior_stoch_logits, posterior_stoch_logits = self.recurrent_model.forward( h_prev, embed_t, a_t) @@ -370,7 +370,7 @@ def KL(dist1, dist2): losses['loss_reconstruction'] += -x_r.log_prob(obs).mean() losses['loss_reward_pred'] += -r_pred.log_prob(r_c).mean() - losses['loss_discount_pred'] += -f_pred.log_prob(f_c.type(torch.float32)).mean() + losses['loss_discount_pred'] += -f_pred.log_prob(d_c).mean() # NOTE: entropy can be added as metric losses['loss_kl_reg'] += KL(torch.flatten(torch.stack(prior_logits, dim=1), 0, 1), torch.flatten(torch.stack(posterior_logits, dim=1), 0, 1)) @@ -381,8 +381,7 @@ def KL(dist1, dist2): class ImaginativeCritic(nn.Module): def __init__(self, discount_factor: float, update_interval: int, - soft_update_fraction: float, value_target_lambda: float, latent_dim: int, - actions_num: int): + soft_update_fraction: float, value_target_lambda: float, latent_dim: int): super().__init__() self.gamma = discount_factor self.critic_update_interval = update_interval @@ -413,16 +412,20 @@ def update_target(self): def estimate_value(self, z) -> torch.Tensor: return self.critic(z) - def lambda_return(self, zs, rs, ts): - v_lambdas = [self.target_critic(zs[-1])] - for i in range(zs.shape[0] - 2, -1, -1): - v_lambda = rs[i] + ts[i] * self.gamma * ( - (1 - self.lambda_) * self.target_critic(zs[i]).detach() + + def _lambda_return(self, vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor): + v_lambdas = [rs[-1] + self.gamma*vs[-1]] + for i in range(vs.shape[0] - 2, -1, -1): + v_lambda = rs[i] + ds[i] * self.gamma * ( + (1 - self.lambda_) * vs[i] + self.lambda_ * v_lambdas[-1]) v_lambdas.append(v_lambda) return torch.stack(list(reversed(v_lambdas))) + def lambda_return(self, zs, rs, ds): + vs = self.target_critic(zs).detach() + return self._lambda_return(vs, rs, ds) + class DreamerV2(RlAgent): @@ -472,8 +475,7 @@ def __init__( self.critic = 
ImaginativeCritic(discount_factor, critic_update_interval, critic_soft_update_fraction, critic_value_target_lambda, - latent_dim * latent_classes, - actions_num).to(device_type) + latent_dim * latent_classes).to(device_type) self.scaler = torch.cuda.amp.GradScaler() self.world_model_optimizer = torch.optim.AdamW(self.world_model.parameters(), @@ -501,20 +503,22 @@ def imagine_trajectory( z = z_0.detach() for i in range(horizon): if precomp_actions is not None: + a_dist = None a = precomp_actions[i].unsqueeze(0) else: - a = self.actor(z).rsample() - world_state, next_z, reward, is_finished = self.world_model.predict_next( + a_dist = self.actor(z) + a = a_dist.rsample() + world_state, next_z, reward, discount = self.world_model.predict_next( z, a, world_state) # FIXME: - # is_finished should be shifted, as they imply whether the following state + # discount factors should be shifted, as they imply whether the following state # will be valid, not whether the current state is valid. zs.append(z) - actions.append(a) + actions.append(a_dist) next_zs.append(next_z) rewards.append(reward) - ts.append(is_finished) + ts.append(discount) determs.append(world_state[0]) z = next_z.detach() @@ -594,8 +598,6 @@ def _generate_video_with_update(self, obs: list[Observation], init_action: list[ determ, stoch_logits = self.world_model.get_latent(o.unsqueeze(0), a.unsqueeze(0).unsqueeze(0), state) z_0 = Dist(stoch_logits).rsample().reshape(-1, 32 * 32).unsqueeze(0) state = (determ, z_0) - # zs, _, _, _, _, determs = self.imagine_trajectory(z_0.squeeze(0), horizon=1) - # video_r = self.world_model.image_predictor(torch.concat([determs, zs], dim=2)).rsample().cpu().detach().numpy() video_r = self.world_model.image_predictor(torch.concat([determ, z_0], dim=-1)).mode.cpu().detach().numpy() video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) video.append(video_r) @@ -650,12 +652,12 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation a = F.one_hot(a, num_classes=self.actions_num).squeeze() r = self.from_np(r) next_obs = self.preprocess_obs(self.from_np(next_obs)) - is_finished = self.from_np(is_finished) + discount_factors = (1 - self.from_np(is_finished).type(torch.float32)) - # take some latent embeddings as initial step + # take some latent embeddings as initial with torch.cuda.amp.autocast(enabled=False): losses, discovered_latents = self.world_model.calculate_loss( - next_obs, a, r, is_finished) + next_obs, a, r, discount_factors) # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device # world_model_loss = torch.Tensor(1).to(next(self.world_model.parameters()).device) @@ -674,51 +676,50 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation nn.utils.clip_grad_norm_(self.world_model.parameters(), 100) self.scaler.step(self.world_model_optimizer) - # with torch.cuda.amp.autocast(enabled=False): - # idx = torch.randperm(discovered_latents.size(0)) - # initial_states = discovered_latents[idx] - - # losses_ac = defaultdict( - # lambda: torch.zeros(1).to(next(self.critic.parameters()).device)) - - # zs, action_dists, next_zs, rewards, terminal_flags, _ = self.imagine_trajectory( - # initial_states) - # vs = self.critic.lambda_return(next_zs, rewards, terminal_flags) - - # losses_ac['loss_critic'] = F.mse_loss(self.critic.estimate_value(next_zs.detach()), - # vs.detach(), - # reduction='sum') / zs.shape[1] - # losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control - # losses_ac['loss_actor_dynamics_backprop'] = -( 
- # (1 - self.rho) * vs).sum() / zs.shape[1] - # # FIXME: Is it correct to use normal entropy with Tanh transformation - # losses_ac['loss_actor_entropy'] = -self.eta * torch.stack( - # [a.base_dist.base_dist.entropy() for a in action_dists[:-1]]).sum() / zs.shape[1] - # losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ - # 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] - - # self.actor_optimizer.zero_grad(set_to_none=True) - # self.critic_optimizer.zero_grad(set_to_none=True) - - # self.scaler.scale(losses_ac['loss_critic']).backward() - # self.scaler.scale(losses_ac['loss_actor']).backward() - - # self.scaler.unscale_(self.actor_optimizer) - # self.scaler.unscale_(self.critic_optimizer) - # nn.utils.clip_grad_norm_(self.actor.parameters(), 100) - # nn.utils.clip_grad_norm_(self.critic.parameters(), 100) - - # self.scaler.step(self.actor_optimizer) - # self.scaler.step(self.critic_optimizer) - - # self.critic.update_target() + with torch.cuda.amp.autocast(enabled=False): + idx = torch.randperm(discovered_latents.size(0)) + initial_states = discovered_latents[idx] + + losses_ac = defaultdict( + lambda: torch.zeros(1).to(next(self.critic.parameters()).device)) + + zs, action_dists, next_zs, rewards, discount_factors, _ = self.imagine_trajectory( + initial_states) + vs = self.critic.lambda_return(next_zs, rewards, discount_factors) + + losses_ac['loss_critic'] = F.mse_loss(self.critic.estimate_value(next_zs.detach()), + vs.detach(), + reduction='sum') / zs.shape[1] + losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control + losses_ac['loss_actor_dynamics_backprop'] = -( + (1 - self.rho) * vs).sum() / zs.shape[1] + # FIXME: Is it correct to use normal entropy with Tanh transformation + losses_ac['loss_actor_entropy'] = -self.eta * torch.stack( + [a.base_dist.base_dist.entropy() for a in action_dists[:-1]]).sum() / zs.shape[1] + losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ + 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] + + self.actor_optimizer.zero_grad(set_to_none=True) + self.critic_optimizer.zero_grad(set_to_none=True) + + self.scaler.scale(losses_ac['loss_critic']).backward() + self.scaler.scale(losses_ac['loss_actor']).backward() + + self.scaler.unscale_(self.actor_optimizer) + self.scaler.unscale_(self.critic_optimizer) + nn.utils.clip_grad_norm_(self.actor.parameters(), 100) + nn.utils.clip_grad_norm_(self.critic.parameters(), 100) + + self.scaler.step(self.actor_optimizer) + self.scaler.step(self.critic_optimizer) + + self.critic.update_target() self.scaler.update() losses = {l: val.detach().cpu().item() for l, val in losses.items()} - # losses_ac = {l: val.detach().cpu().item() for l, val in losses_ac.items()} + losses_ac = {l: val.detach().cpu().item() for l, val in losses_ac.items()} - # return losses | losses_ac - return losses + return losses | losses_ac def save_ckpt(self, epoch_num: int, losses: dict[str, float]): torch.save( diff --git a/config/agent/dqn.yaml b/rl_sandbox/config/agent/dqn.yaml similarity index 100% rename from config/agent/dqn.yaml rename to rl_sandbox/config/agent/dqn.yaml diff --git a/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml similarity index 100% rename from config/agent/dreamer_v2.yaml rename to rl_sandbox/config/agent/dreamer_v2.yaml diff --git a/config/config.yaml b/rl_sandbox/config/config.yaml similarity index 100% rename from config/config.yaml rename to rl_sandbox/config/config.yaml diff --git a/config/env/dm_cartpole.yaml 
b/rl_sandbox/config/env/dm_cartpole.yaml similarity index 100% rename from config/env/dm_cartpole.yaml rename to rl_sandbox/config/env/dm_cartpole.yaml diff --git a/config/env/mock.yaml b/rl_sandbox/config/env/mock.yaml similarity index 100% rename from config/env/mock.yaml rename to rl_sandbox/config/env/mock.yaml diff --git a/rl_sandbox/test/dreamer/test_critic.py b/rl_sandbox/test/dreamer/test_critic.py new file mode 100644 index 0000000..6d90ea5 --- /dev/null +++ b/rl_sandbox/test/dreamer/test_critic.py @@ -0,0 +1,62 @@ +import pytest +import torch + +from rl_sandbox.agents.dreamer_v2 import ImaginativeCritic + +@pytest.fixture +def imaginative_critic(): + return ImaginativeCritic(discount_factor=1, + update_interval=100, + soft_update_fraction=1, + value_target_lambda=0.95, + latent_dim=10) + +def test_lambda_return_discount_0(imaginative_critic): + # Should just return rewards if discount_factor is 0 + imaginative_critic.lambda_ = 0 + imaginative_critic.gamma = 0 + rs = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + vs = torch.ones_like(rs) + ts = torch.ones_like(rs) + lambda_ret = imaginative_critic._lambda_return(vs, rs, ts) + assert torch.all(lambda_ret == rs) + +def test_lambda_return_lambda_0(imaginative_critic): + # Should return 1-step return if lambda is 0 + imaginative_critic.lambda_ = 0 + imaginative_critic.gamma = 1 + vs = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + rs = torch.ones_like(vs) + ts = torch.ones_like(vs) + lambda_ret = imaginative_critic._lambda_return(vs, rs, ts) + assert torch.all(lambda_ret == torch.Tensor([2, 3, 4, 5, 6, 7, 8, 9, 10, 11])) + +def test_lambda_return_lambda_0_gamma_0_5(imaginative_critic): + # Should return 1-step return if lambda is 0 + imaginative_critic.lambda_ = 0 + imaginative_critic.gamma = 0.5 + vs = torch.Tensor([2, 2, 4, 4, 6, 6, 8, 8, 10, 10]) + rs = torch.ones_like(vs) + ts = torch.ones_like(vs) + lambda_ret = imaginative_critic._lambda_return(vs, rs, ts) + assert torch.all(lambda_ret == torch.Tensor([2, 2, 3, 3, 4, 4, 5, 5, 6, 6])) + +def test_lambda_return_lambda_1(imaginative_critic): + # Should return Monte-Carlo return + imaginative_critic.lambda_ = 1 + imaginative_critic.gamma = 1 + vs = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + rs = torch.ones_like(vs) + ts = torch.ones_like(vs) + lambda_ret = imaginative_critic._lambda_return(vs, rs, ts) + assert torch.all(lambda_ret == torch.Tensor([20, 19, 18, 17, 16, 15, 14, 13, 12, 11])) + +def test_lambda_return_lambda_1_gamma_0_5(imaginative_critic): + # Should return Monte-Carlo return + imaginative_critic.lambda_ = 1 + imaginative_critic.gamma = 0.5 + vs = torch.Tensor([2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]) + rs = torch.zeros_like(vs) + ts = torch.ones_like(vs) + lambda_ret = imaginative_critic._lambda_return(vs, rs, ts) + assert torch.all(lambda_ret == torch.Tensor([1, 2, 4, 8, 16, 32, 64, 128, 256, 512])) diff --git a/tests/test_linear_scheduler.py b/rl_sandbox/test/test_linear_scheduler.py similarity index 100% rename from tests/test_linear_scheduler.py rename to rl_sandbox/test/test_linear_scheduler.py diff --git a/tests/test_replay_buffer.py b/rl_sandbox/test/test_replay_buffer.py similarity index 100% rename from tests/test_replay_buffer.py rename to rl_sandbox/test/test_replay_buffer.py diff --git a/main.py b/rl_sandbox/train.py similarity index 100% rename from main.py rename to rl_sandbox/train.py From ce915667fe77cb688fdec337467b4216e6b70df3 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 14 Jan 2023 16:57:25 +0000 Subject: [PATCH 027/106] Changed 
discount factor to enforce 0 after first 0, fixed edge cases for actor loss --- rl_sandbox/agents/dreamer_v2.py | 42 ++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 08b072d..e8be537 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -5,9 +5,9 @@ import matplotlib.pyplot as plt import numpy as np import torch +import torch.distributions as td from torch import nn from torch.nn import functional as F -import torch.distributions as td from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.fc_nn import fc_nn_generator @@ -393,12 +393,14 @@ def __init__(self, discount_factor: float, update_interval: int, 1, 400, 1, - intermediate_activation=nn.ELU) + intermediate_activation=nn.ELU, + final_activation=DistLayer('mse')) self.target_critic = fc_nn_generator(latent_dim, 1, 400, 1, - intermediate_activation=nn.ELU) + intermediate_activation=nn.ELU, + final_activation=DistLayer('mse')) def update_target(self): if self._update_num == 0: @@ -409,7 +411,7 @@ def update_target(self): (1 - mix) * target_param.data) self._update_num = (self._update_num + 1) % self.critic_update_interval - def estimate_value(self, z) -> torch.Tensor: + def estimate_value(self, z) -> td.Distribution: return self.critic(z) def _lambda_return(self, vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor): @@ -423,7 +425,7 @@ def _lambda_return(self, vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor): return torch.stack(list(reversed(v_lambdas))) def lambda_return(self, zs, rs, ds): - vs = self.target_critic(zs).detach() + vs = self.target_critic(zs).rsample().detach() return self._lambda_return(vs, rs, ds) @@ -510,9 +512,6 @@ def imagine_trajectory( a = a_dist.rsample() world_state, next_z, reward, discount = self.world_model.predict_next( z, a, world_state) - # FIXME: - # discount factors should be shifted, as they imply whether the following state - # will be valid, not whether the current state is valid. 
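For reference, the recursion behind _lambda_return and the tests added in patch 026 is V_lambda[t] = r[t] + d[t] * gamma * ((1 - lambda) * v[t] + lambda * V_lambda[t+1]), bootstrapped at the end with V_lambda[T] = r[T] + gamma * v[T]. A standalone sketch mirroring that implementation; the default gamma and lam values here are for illustration only:

import torch

def lambda_return(vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor,
                  gamma: float = 0.995, lam: float = 0.95) -> torch.Tensor:
    # vs[i] is the bootstrap value paired with reward rs[i],
    # ds[i] the discount weight applied at that step.
    returns = [rs[-1] + gamma * vs[-1]]
    for i in range(vs.shape[0] - 2, -1, -1):
        returns.append(rs[i] + ds[i] * gamma *
                       ((1 - lam) * vs[i] + lam * returns[-1]))
    return torch.stack(list(reversed(returns)))

# With lam=0 this reduces to the 1-step return r[t] + gamma * v[t];
# with lam=1 it becomes the Monte-Carlo return, matching the new unit tests.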
zs.append(z) actions.append(a_dist) @@ -655,7 +654,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation discount_factors = (1 - self.from_np(is_finished).type(torch.float32)) # take some latent embeddings as initial - with torch.cuda.amp.autocast(enabled=False): + with torch.cuda.amp.autocast(enabled=True): losses, discovered_latents = self.world_model.calculate_loss( next_obs, a, r, discount_factors) @@ -676,7 +675,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation nn.utils.clip_grad_norm_(self.world_model.parameters(), 100) self.scaler.step(self.world_model_optimizer) - with torch.cuda.amp.autocast(enabled=False): + with torch.cuda.amp.autocast(enabled=True): idx = torch.randperm(discovered_latents.size(0)) initial_states = discovered_latents[idx] @@ -685,17 +684,28 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation zs, action_dists, next_zs, rewards, discount_factors, _ = self.imagine_trajectory( initial_states) + + # Ignore all factors after first is_finished state + discount_factors = torch.cumprod(discount_factors, dim=1) + + # Discounted factors should be shifted as they predict whether next state is terminal + # First discount factor on contrary is always 1 as it cannot lead to trajectory finish + discount_factors = torch.cat([torch.ones_like(discount_factors[:1, :]), discount_factors[:-1, :]], dim=0) + vs = self.critic.lambda_return(next_zs, rewards, discount_factors) - losses_ac['loss_critic'] = F.mse_loss(self.critic.estimate_value(next_zs.detach()), - vs.detach(), - reduction='sum') / zs.shape[1] + losses_ac['loss_critic'] = -(self.critic.estimate_value(next_zs.detach()).log_prob( + vs[:-1].detach()).unsqueeze(-1) * discount_factors).mean() + + # last action should be ignored as it is not used to predict next state, thus no feedback + # first value should be ignored as it is comes from replay buffer losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control losses_ac['loss_actor_dynamics_backprop'] = -( - (1 - self.rho) * vs).sum() / zs.shape[1] + (1 - self.rho) * (vs[1:-1]*discount_factors[1:-1])).mean() # FIXME: Is it correct to use normal entropy with Tanh transformation - losses_ac['loss_actor_entropy'] = -self.eta * torch.stack( - [a.base_dist.base_dist.entropy() for a in action_dists[:-1]]).sum() / zs.shape[1] + losses_ac['loss_actor_entropy'] = -(self.eta * + torch.stack([a.base_dist.base_dist.entropy() for a in action_dists[:-1]]) * discount_factors[-1]).mean() + losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] From c4809d720681e3c493f989a5bb20d13a17d89fb2 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 14 Jan 2023 22:35:05 +0000 Subject: [PATCH 028/106] Minor fixes, which doesn't fix anything --- .vimspector.json | 2 +- pyproject.toml | 5 +---- rl_sandbox/agents/dreamer_v2.py | 26 ++++++++++++-------------- rl_sandbox/train.py | 19 +++++++++---------- rl_sandbox/utils/env.py | 2 +- 5 files changed, 24 insertions(+), 30 deletions(-) diff --git a/.vimspector.json b/.vimspector.json index 5138480..4dfda9e 100644 --- a/.vimspector.json +++ b/.vimspector.json @@ -39,7 +39,7 @@ "Run main": { "extends": "python-base", "configuration": { - "program": "main.py", + "program": "rl_sandbox/train.py", "args": [] } } diff --git a/pyproject.toml b/pyproject.toml index 10d0305..c4bdc92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,10 +7,7 @@ name = 'rl_sandbox' version = "0.1.0" 
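On the FIXME in the previous patch about using a Normal entropy under a Tanh transformation: PyTorch's TransformedDistribution has no closed-form entropy, and a later patch in this series falls back to a Monte-Carlo estimate (the negative log_prob of samples). A self-contained sketch of that estimator; the sample count and the toy squashed policy below are illustrative only:

import torch
import torch.distributions as td

def sampled_entropy(dist: td.Distribution, n: int = 128) -> torch.Tensor:
    # H[p] = -E_p[log p(x)], estimated with n reparameterized samples.
    x = dist.rsample((n,))
    return -dist.log_prob(x).mean(0)

# Example: a tanh-squashed Gaussian head over 4 action dimensions.
base = td.Normal(torch.zeros(4), torch.ones(4))
squashed = td.TransformedDistribution(base, [td.TanhTransform(cache_size=1)])
print(sampled_entropy(squashed))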
description = 'Sandbox for my RL experiments' authors = ['Roman Milishchuk '] -packages = [{include = 'rl_sandbox', exclude = 'rl_sandbox.test*'}] - -[tool.poetry.scripts] -my-script = ['rl_sandbox.train:main'] +packages = [{include = 'rl_sandbox'}] [tool.poetry.dependencies] python = "^3.10" diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index e8be537..139bef0 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -348,13 +348,6 @@ def KL(dist1, dist2): posterior_stoch = Dist(posterior_stoch_logits).rsample().reshape( -1, self.latent_dim * self.latent_classes) - # inp_t = torch.concat([determ_t.squeeze(0), posterior_stoch], dim=-1) - # x_r = self.image_predictor(inp_t) - # inps.append(inp_t) - # reconstructed.append(x_r) - #losses['loss_reconstruction'] += nn.functional.mse_loss(x_t, x_r) - #losses['loss_kl_reg'] += KL(prior_stoch_logits, posterior_stoch_logits) - h_prev = [determ_t, posterior_stoch.unsqueeze(0)] determ_vars.append(determ_t.squeeze(0)) latent_vars.append(posterior_stoch) @@ -392,13 +385,13 @@ def __init__(self, discount_factor: float, update_interval: int, self.critic = fc_nn_generator(latent_dim, 1, 400, - 1, + 4, intermediate_activation=nn.ELU, final_activation=DistLayer('mse')) self.target_critic = fc_nn_generator(latent_dim, 1, 400, - 1, + 4, intermediate_activation=nn.ELU, final_activation=DistLayer('mse')) @@ -676,8 +669,10 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation self.scaler.step(self.world_model_optimizer) with torch.cuda.amp.autocast(enabled=True): - idx = torch.randperm(discovered_latents.size(0)) - initial_states = discovered_latents[idx] + # idx = torch.randperm(discovered_latents.size(0)) + # initial_states = discovered_latents[idx] + # Dreamer does not shuffle + initial_states = discovered_latents losses_ac = defaultdict( lambda: torch.zeros(1).to(next(self.critic.parameters()).device)) @@ -685,17 +680,20 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation zs, action_dists, next_zs, rewards, discount_factors, _ = self.imagine_trajectory( initial_states) + # Discount prediction is disabled for dmc vision in Dreamer + discount_factors = self.critic.gamma * torch.zeros_like(rewards) + # Ignore all factors after first is_finished state - discount_factors = torch.cumprod(discount_factors, dim=1) + discount_factors = torch.cumprod(discount_factors, dim=1).detach() # Discounted factors should be shifted as they predict whether next state is terminal # First discount factor on contrary is always 1 as it cannot lead to trajectory finish discount_factors = torch.cat([torch.ones_like(discount_factors[:1, :]), discount_factors[:-1, :]], dim=0) - vs = self.critic.lambda_return(next_zs, rewards, discount_factors) + vs = self.critic.lambda_return(next_zs, rewards, discount_factors).detach() losses_ac['loss_critic'] = -(self.critic.estimate_value(next_zs.detach()).log_prob( - vs[:-1].detach()).unsqueeze(-1) * discount_factors).mean() + vs).unsqueeze(-1) * discount_factors).mean() # last action should be ignored as it is not used to predict next state, thus no feedback # first value should be ignored as it is comes from replay buffer diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index c4f2166..06fffe8 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -4,6 +4,7 @@ from torch.utils.tensorboard.writer import SummaryWriter from tqdm import tqdm from pathlib import Path +import random import torch from torch.profiler import 
profile, record_function, ProfilerActivity @@ -14,7 +15,7 @@ from rl_sandbox.utils.env import Env from rl_sandbox.utils.replay_buffer import ReplayBuffer from rl_sandbox.utils.persistent_replay_buffer import PersistentReplayBuffer -from rl_sandbox.utils.rollout_generation import (collect_rollout, collect_rollout_num, iter_rollout, iter_rollout_async, +from rl_sandbox.utils.rollout_generation import (collect_rollout, collect_rollout_num, iter_rollout, fillup_replay_buffer) from rl_sandbox.utils.schedulers import LinearScheduler @@ -37,10 +38,12 @@ def main(cfg: DictConfig): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True + random.seed(cfg.seed) + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + env: Env = hydra.utils.instantiate(cfg.env) - # TODO: add replay buffer implementation, which stores rollouts - # on disk buff = ReplayBuffer() fillup_replay_buffer(env, buff, max(cfg.training.prefill, cfg.training.batch_size)) @@ -48,7 +51,6 @@ def main(cfg: DictConfig): # TODO: Implement smarter techniques for exploration # (Plan2Explore, etc) - agent = hydra.utils.instantiate(cfg.agent, obs_space_num=env.observation_space.shape[0], # FIXME: feels bad @@ -57,12 +59,7 @@ def main(cfg: DictConfig): actions_num=env.action_space.shape[0], action_type='continuous', device_type=cfg.device_type) - # agent = ExplorativeAgent( - # policy_agent, - # # TODO: For dreamer, add noise for sampling instead - # # of just random actions - # RandomAgent(env), - # LinearScheduler(0.9, 0.01, 5_000)) + writer = SummaryWriter() prof = profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], @@ -79,6 +76,8 @@ def main(cfg: DictConfig): log_every_n = 25 st = int(cfg.training.pretrain) // log_every_n + # FIXME: extract logging to seperate entity to omit + # copy-paste if i % log_every_n == 0: rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) # TODO: make logs visualization in separate process diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py index 6cc05c0..58e08bb 100644 --- a/rl_sandbox/utils/env.py +++ b/rl_sandbox/utils/env.py @@ -227,7 +227,7 @@ def _observation_space(self): return gym.spaces.Box(0, 255, self.obs_res + (3, ), dtype=np.uint8) else: raise NotImplementedError( - "Currently run on pixels is supported for 'dm_control'") + "Currently run on pixels is only supported for 'dm_control'") # for space in self.env.observation_spec(): # obs_space_num = sum([v.shape[0] for v in env.observation_space().values()]) From eacce41a67cde9af909fd58ac57556c65ff431e0 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 14 Jan 2023 23:47:10 +0000 Subject: [PATCH 029/106] Added metrics for AC --- rl_sandbox/agents/dreamer_v2.py | 48 ++++++++++++++++++++++----------- rl_sandbox/train.py | 6 +++-- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 139bef0..33219d8 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -297,7 +297,7 @@ def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor next_repr = Dist(next_repr_logits).rsample().reshape( -1, self.latent_dim * self.latent_classes) reward = self.reward_predictor( - torch.concat([determ_state.squeeze(0), next_repr], dim=1)).rsample() + torch.concat([determ_state.squeeze(0), next_repr], dim=1)).mode discount_factors = self.discount_predictor(torch.concat([determ_state.squeeze(0), next_repr], dim=1)).sample() return determ_state, next_repr, reward, 
discount_factors @@ -418,7 +418,7 @@ def _lambda_return(self, vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor): return torch.stack(list(reversed(v_lambdas))) def lambda_return(self, zs, rs, ds): - vs = self.target_critic(zs).rsample().detach() + vs = self.target_critic(zs).mode.detach() return self._lambda_return(vs, rs, ds) @@ -448,6 +448,7 @@ def __init__( critic_lr: float, device_type: str = 'cpu'): + self.device = device_type self.imagination_horizon = imagination_horizon self.cluster_size = batch_cluster_size self.actions_num = actions_num @@ -515,15 +516,14 @@ def imagine_trajectory( z = next_z.detach() return (torch.stack(zs), actions, torch.stack(next_zs), - torch.stack(rewards).detach(), torch.stack(ts).detach(), torch.stack(determs)) + torch.stack(rewards), torch.stack(ts), torch.stack(determs)) def reset(self): self._state = None # FIXME: instead of zero, it should be mode of distribution - self._last_action = torch.zeros((1, 1, self.actions_num), - device=next(self.world_model.parameters()).device) - self._latent_probs = torch.zeros((32, 32), device=next(self.world_model.parameters()).device) - self._action_probs = torch.zeros((self.actions_num), device=next(self.world_model.parameters()).device) + self._last_action = torch.zeros((1, 1, self.actions_num), device=self.device) + self._latent_probs = torch.zeros((32, 32), device=self.device) + self._action_probs = torch.zeros((self.actions_num), device=self.device) self._stored_steps = 0 @staticmethod @@ -537,7 +537,7 @@ def preprocess_obs(obs: torch.Tensor): def get_action(self, obs: Observation) -> Action: # NOTE: pytorch fails without .copy() only when get_action is called - obs = torch.from_numpy(obs.copy()).to(next(self.world_model.parameters()).device) + obs = torch.from_numpy(obs.copy()).to(self.device) obs = self.preprocess_obs(obs).unsqueeze(0) determ, latent_repr_logits = self.world_model.get_latent(obs, self._last_action, @@ -560,7 +560,7 @@ def get_action(self, obs: Observation) -> Action: return self._last_action.squeeze().detach().cpu().numpy() def _generate_video(self, obs: Observation, actions: list[Action]): - obs = torch.from_numpy(obs.copy()).to(next(self.world_model.parameters()).device) + obs = torch.from_numpy(obs.copy()).to(self.device) obs = self.preprocess_obs(obs).unsqueeze(0) if False: @@ -576,7 +576,7 @@ def _generate_video(self, obs: Observation, actions: list[Action]): return video_r def _generate_video_with_update(self, obs: list[Observation], init_action: list[Action]): - obs = torch.from_numpy(obs.copy()).to(next(self.world_model.parameters()).device) + obs = torch.from_numpy(obs.copy()).to(self.device) obs = self.preprocess_obs(obs) if False: @@ -633,7 +633,7 @@ def viz_log(self, rollout, logger, epoch_num): def from_np(self, arr: np.ndarray): arr = torch.from_numpy(arr) if isinstance(arr, np.ndarray) else arr - return arr.to(next(self.world_model.parameters()).device, non_blocking=True) + return arr.to(self.device, non_blocking=True) def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observations, is_finished: TerminationFlags): @@ -652,7 +652,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation next_obs, a, r, discount_factors) # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device - # world_model_loss = torch.Tensor(1).to(next(self.world_model.parameters()).device) + # world_model_loss = torch.Tensor(1).to(self.device) world_model_loss = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + 
losses['loss_kl_reg'] + @@ -668,6 +668,9 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation nn.utils.clip_grad_norm_(self.world_model.parameters(), 100) self.scaler.step(self.world_model_optimizer) + metrics = defaultdict( + lambda: torch.zeros(1).to(self.device)) + with torch.cuda.amp.autocast(enabled=True): # idx = torch.randperm(discovered_latents.size(0)) # initial_states = discovered_latents[idx] @@ -675,7 +678,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation initial_states = discovered_latents losses_ac = defaultdict( - lambda: torch.zeros(1).to(next(self.critic.parameters()).device)) + lambda: torch.zeros(1).to(self.device)) zs, action_dists, next_zs, rewards, discount_factors, _ = self.imagine_trajectory( initial_states) @@ -692,8 +695,13 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation vs = self.critic.lambda_return(next_zs, rewards, discount_factors).detach() - losses_ac['loss_critic'] = -(self.critic.estimate_value(next_zs.detach()).log_prob( - vs).unsqueeze(-1) * discount_factors).mean() + predicted_vs_dist = self.critic.estimate_value(next_zs.detach()) + losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs).unsqueeze(-1) * discount_factors).mean() + + metrics['critic/avg_value'] = self.critic.target_critic(next_zs).mode.mean() + for i in range(self.imagination_horizon): + metrics[f'critic/avg_lambda_value_{i}'] = vs[i].mean() + metrics['critic/avg_predicted_value'] = predicted_vs_dist.mode.mean() # last action should be ignored as it is not used to predict next state, thus no feedback # first value should be ignored as it is comes from replay buffer @@ -707,6 +715,13 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] + # mean and std are estimated statistically as tanh transformation is used + act_avg = torch.stack([a.expand((128, *a.batch_shape)).sample().mean(dim=0) for a in action_dists[:-1]]) + metrics['actor/avg_val'] = act_avg.mean() + metrics['actor/avg_sd'] = ((torch.stack([a.expand((128, *a.batch_shape)).sample() for a in action_dists[:-1]], dim=1) - act_avg)**2).mean().sqrt() + metrics['actor/min_val'] = torch.stack([a.expand((128, *a.batch_shape)).sample() for a in action_dists[:-1]]).min() + metrics['actor/max_val'] = torch.stack([a.expand((128, *a.batch_shape)).sample() for a in action_dists[:-1]]).max() + self.actor_optimizer.zero_grad(set_to_none=True) self.critic_optimizer.zero_grad(set_to_none=True) @@ -726,8 +741,9 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation losses = {l: val.detach().cpu().item() for l, val in losses.items()} losses_ac = {l: val.detach().cpu().item() for l, val in losses_ac.items()} + metrics = {l: val.detach().cpu().item() for l, val in metrics.items()} - return losses | losses_ac + return losses | losses_ac | metrics def save_ckpt(self, epoch_num: int, losses: dict[str, float]): torch.save( diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 06fffe8..cab1c5a 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -112,8 +112,10 @@ def main(cfg: DictConfig): losses = agent.train(s, a, r, n, f) if cfg.debug.profiler: prof.step() - for loss_name, loss in losses.items(): - writer.add_scalar(f'train/{loss_name}', loss, global_step) + # NOTE: Do not forget to run test with every step to check for outliers + if 
global_step % 10 == 0: + for loss_name, loss in losses.items(): + writer.add_scalar(f'train/{loss_name}', loss, global_step) global_step += cfg.env.repeat_action_num pbar.update(cfg.env.repeat_action_num) From 6ac194ffd402a5c9962c3372c0addaa2d65358f8 Mon Sep 17 00:00:00 2001 From: Midren Date: Thu, 19 Jan 2023 17:10:20 +0000 Subject: [PATCH 030/106] WIP --- rl_sandbox/agents/dreamer_v2.py | 118 +++++++++++++++++++++++--------- 1 file changed, 86 insertions(+), 32 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 33219d8..29f23ac 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -15,6 +15,8 @@ Observations, Rewards, TerminationFlags) +from rl_sandbox.TruncatedNormal import TruncatedNormal + class View(nn.Module): @@ -35,6 +37,8 @@ def __init__(self, min_std: float, std_trans: str = 'sigmoid2', transform: t.Opt super().__init__() self.min_std = min_std match std_trans: + case 'identity': + self.std_trans = nn.Identity() case 'softplus': self.std_trans = nn.Softplus() case 'sigmoid': @@ -72,7 +76,19 @@ def __init__(self, type: str): case 'onehot': self.dist = lambda x: td.OneHotCategoricalStraightThrough(logits=x) case 'normal_tanh': - self.dist = NormalWithOffset(min_std=0.1, transform='tanh') + def get_tanh_normal(x, min_std=0.1): + mean, std = x.chunk(2, dim=-1) + mean = 5 * torch.tanh(mean / 5) + std = F.softplus(std) + min_std + dist = td.Normal(mean, std) + return td.TransformedDistribution(dist, [td.TanhTransform(cache_size=1)]) + self.dist = get_tanh_normal + case 'normal_trunc': + def get_trunc_normal(x, min_std=0.1): + mean, std = x.chunk(2, dim=-1) + std = 2 * torch.sigmoid(std / 2) + min_std + return TruncatedNormal(torch.tanh(mean), std, a=-1 + 1e-6, b=1 - 1e-6) + self.dist = get_trunc_normal case 'binary': self.dist = lambda x: td.Bernoulli(logits=x) case _: @@ -256,6 +272,28 @@ def forward(self, X): return td.Independent(td.Normal(self.net(x), 1.0), 3) +class Normalizer(nn.Module): + def __init__(self, momentum=0.99, scale=1.0, eps=1e-8): + super().__init__() + self.momentum = momentum + self.scale = scale + self.eps= eps + self.mag = torch.ones(1, dtype=torch.float32) + self.mag.requires_grad = False + + def forward(self, x): + self.update(x) + val = self.trans(x) + return val + + def update(self, x): + self.mag = self.momentum * self.mag.to(x.device) + (1 - self.momentum) * (x.abs().mean()).detach() + + def trans(self, x): + x = x / (self.mag + self.eps) + return x*self.scale + + class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, @@ -289,6 +327,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, num_layers=4, intermediate_activation=nn.ELU, final_activation=DistLayer('binary')) + self.reward_normalizer = Normalizer(momentum=0.99, scale=1.0, eps=1e-8) def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor]): determ_state, next_repr_logits = self.recurrent_model.predict_next( @@ -418,7 +457,7 @@ def _lambda_return(self, vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor): return torch.stack(list(reversed(v_lambdas))) def lambda_return(self, zs, rs, ds): - vs = self.target_critic(zs).mode.detach() + vs = self.target_critic(zs).mode return self._lambda_return(vs, rs, ds) @@ -491,32 +530,33 @@ def __init__( def imagine_trajectory( self, z_0, precomp_actions: t.Optional[list[Action]] = None, horizon: t.Optional[int] = None ) -> tuple[torch.Tensor, torch.distributions.Distribution, 
torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: + torch.Tensor, torch.Tensor, torch.Tensor]: if horizon is None: horizon = self.imagination_horizon world_state = None - zs, actions, next_zs, rewards, ts, determs = [], [], [], [], [], [] - z = z_0.detach() + zs, actions_dists, actions, next_zs, rewards, ts, determs = [], [], [], [], [], [], [] + z = z_0 for i in range(horizon): if precomp_actions is not None: a_dist = None a = precomp_actions[i].unsqueeze(0) else: - a_dist = self.actor(z) + a_dist = self.actor(z.detach()) a = a_dist.rsample() world_state, next_z, reward, discount = self.world_model.predict_next( z, a, world_state) zs.append(z) - actions.append(a_dist) + actions_dists.append(a_dist) next_zs.append(next_z) rewards.append(reward) ts.append(discount) determs.append(world_state[0]) + actions.append(a) - z = next_z.detach() - return (torch.stack(zs), actions, torch.stack(next_zs), - torch.stack(rewards), torch.stack(ts), torch.stack(determs)) + z = next_z + return (torch.stack(zs), actions_dists, torch.stack(next_zs), + torch.stack(rewards), torch.stack(ts), torch.stack(determs), torch.stack(actions)) def reset(self): self._state = None @@ -569,7 +609,7 @@ def _generate_video(self, obs: Observation, actions: list[Action]): else: actions = self.from_np(actions) z_0 = Dist(self.world_model.get_latent(obs, actions[0].unsqueeze(0).unsqueeze(0), None)[1]).rsample().reshape(-1, 32 * 32).unsqueeze(0) - zs, _, _, _, _, determs = self.imagine_trajectory(z_0.squeeze(0), actions[1:], horizon=self.imagination_horizon - 1) + zs, _, _, _, _, determs, _ = self.imagine_trajectory(z_0.squeeze(0), actions[1:], horizon=self.imagination_horizon - 1) # video_r = self.world_model.image_predictor(torch.concat([determs, zs], dim=2)).rsample().cpu().detach().numpy() video_r = self.world_model.image_predictor(torch.concat([torch.concat([torch.zeros_like(determs[0]).unsqueeze(0), determs]), torch.concat([z_0, zs])], dim=2)).mode.cpu().detach().numpy() video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) @@ -680,47 +720,61 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation losses_ac = defaultdict( lambda: torch.zeros(1).to(self.device)) - zs, action_dists, next_zs, rewards, discount_factors, _ = self.imagine_trajectory( + zs, action_dists, next_zs, rewards, discount_factors, _, actions = self.imagine_trajectory( initial_states) + rewards = self.world_model.reward_normalizer(rewards) # Discount prediction is disabled for dmc vision in Dreamer - discount_factors = self.critic.gamma * torch.zeros_like(rewards) + discount_factors = self.critic.gamma * torch.ones_like(rewards) # Ignore all factors after first is_finished state - discount_factors = torch.cumprod(discount_factors, dim=1).detach() + discount_factors = torch.cumprod(discount_factors, dim=0).detach() # Discounted factors should be shifted as they predict whether next state is terminal # First discount factor on contrary is always 1 as it cannot lead to trajectory finish discount_factors = torch.cat([torch.ones_like(discount_factors[:1, :]), discount_factors[:-1, :]], dim=0) - vs = self.critic.lambda_return(next_zs, rewards, discount_factors).detach() + # vs = self.critic.lambda_return(zs[1:], rewards[:-1], discount_factors[:-1]) + vs = rewards[:-1] + self.critic.gamma * self.critic.target_critic(zs[1:]).mode + # vs = torch.cat([vs, rewards[-1].unsqueeze(0)], dim=0) + predicted_vs_dist = self.critic.estimate_value(zs[:-1].detach()) + losses_ac['loss_critic'] = 
-(predicted_vs_dist.log_prob(vs.detach()).unsqueeze(-1) * discount_factors[:-1]).mean() - predicted_vs_dist = self.critic.estimate_value(next_zs.detach()) - losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs).unsqueeze(-1) * discount_factors).mean() + # predicted_vs_dist = self.critic.estimate_value(zs[:-2].detach()) + # losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs[:-1].detach()).unsqueeze(-1) * discount_factors[:-2]).mean() - metrics['critic/avg_value'] = self.critic.target_critic(next_zs).mode.mean() - for i in range(self.imagination_horizon): - metrics[f'critic/avg_lambda_value_{i}'] = vs[i].mean() + metrics['critic/avg_target_value'] = self.critic.target_critic(zs[1:]).mode.mean() + metrics['critic/avg_lambda_value'] = vs.mean() metrics['critic/avg_predicted_value'] = predicted_vs_dist.mode.mean() # last action should be ignored as it is not used to predict next state, thus no feedback # first value should be ignored as it is comes from replay buffer - losses_ac['loss_actor_reinforce'] += 0 # unused in dm_control - losses_ac['loss_actor_dynamics_backprop'] = -( - (1 - self.rho) * (vs[1:-1]*discount_factors[1:-1])).mean() - # FIXME: Is it correct to use normal entropy with Tanh transformation - losses_ac['loss_actor_entropy'] = -(self.eta * - torch.stack([a.base_dist.base_dist.entropy() for a in action_dists[:-1]]) * discount_factors[-1]).mean() + baseline = self.critic.target_critic(next_zs[1:-2]).mode + advantage = (vs[1:-1] - baseline).detach() + losses_ac['loss_actor_reinforce'] += -(self.rho * (torch.stack([a.log_prob(actions[idx].detach()) for idx, a in enumerate(action_dists[1:-2])]) * discount_factors[1:-2].squeeze()) * advantage.squeeze()).mean() + losses_ac['loss_actor_dynamics_backprop'] = -((1 - self.rho) * (vs[:-1]*discount_factors[:-2])).mean() + + def calculate_entropy(dist): + # x_t = dist.base_dist.base_dist.rsample() + # return -(dist.base_dist.base_dist.log_prob(x_t) - torch.log(1 - torch.tanh(x_t).pow(2) + 1e-6)).sum(1) + return -dist.log_prob(dist.rsample((128,))).mean(0) + + losses_ac['loss_actor_entropy'] += -(self.eta * + torch.stack([calculate_entropy(a) for a in action_dists[1:-2]]) * discount_factors[1:-2].squeeze()).mean() + # losses_ac['loss_actor_entropy'] += (self.eta * + # torch.stack([a.entropy() for a in action_dists[1:-1]]) * discount_factors[1:-1].squeeze()).mean() - losses_ac['loss_actor'] += losses_ac['loss_actor_reinforce'] + losses_ac[ - 'loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] + losses_ac['loss_actor'] = losses_ac['loss_actor_reinforce'] + losses_ac['loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] # mean and std are estimated statistically as tanh transformation is used - act_avg = torch.stack([a.expand((128, *a.batch_shape)).sample().mean(dim=0) for a in action_dists[:-1]]) + act_avg = torch.stack([a.sample((128,)).mean(dim=0) for a in action_dists[:-1]]) + # act_avg = torch.stack([a.mean for a in action_dists[:-1]]) metrics['actor/avg_val'] = act_avg.mean() - metrics['actor/avg_sd'] = ((torch.stack([a.expand((128, *a.batch_shape)).sample() for a in action_dists[:-1]], dim=1) - act_avg)**2).mean().sqrt() - metrics['actor/min_val'] = torch.stack([a.expand((128, *a.batch_shape)).sample() for a in action_dists[:-1]]).min() - metrics['actor/max_val'] = torch.stack([a.expand((128, *a.batch_shape)).sample() for a in action_dists[:-1]]).max() + metrics['actor/mode_val'] = torch.stack([torch.mode(a.sample((128,)), 0)[0] for a in action_dists[:-1]]).mean() + # metrics['actor/avg_sd'] = 
torch.stack([a.stddev for a in action_dists[:-1]]).mean() + metrics['actor/avg_sd'] = (((torch.stack([a.sample((128,)) for a in action_dists[:-1]], dim=1) - act_avg)**2).mean(0).sqrt()).mean() + metrics['actor/min_val'] = torch.stack([a.sample((128,)) for a in action_dists[:-1]]).min() + metrics['actor/max_val'] = torch.stack([a.sample((128,)) for a in action_dists[:-1]]).max() self.actor_optimizer.zero_grad(set_to_none=True) self.critic_optimizer.zero_grad(set_to_none=True) From 3c4d94e7814b5c1be3de8a76af13afb27fe7fba6 Mon Sep 17 00:00:00 2001 From: Midren Date: Thu, 19 Jan 2023 19:51:48 +0000 Subject: [PATCH 031/106] Using Tanh/Trunc distributions from torchrl, fixed number of layers in FCNN, calculate entropy differently --- rl_sandbox/agents/dreamer_v2.py | 42 ++++++++++++++++----------------- rl_sandbox/utils/fc_nn.py | 3 ++- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 29f23ac..574069a 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -15,7 +15,8 @@ Observations, Rewards, TerminationFlags) -from rl_sandbox.TruncatedNormal import TruncatedNormal +from torchviz import make_dot +from torchrl.modules import TanhNormal, TruncatedNormal class View(nn.Module): @@ -78,16 +79,12 @@ def __init__(self, type: str): case 'normal_tanh': def get_tanh_normal(x, min_std=0.1): mean, std = x.chunk(2, dim=-1) - mean = 5 * torch.tanh(mean / 5) - std = F.softplus(std) + min_std - dist = td.Normal(mean, std) - return td.TransformedDistribution(dist, [td.TanhTransform(cache_size=1)]) + return TanhNormal(mean, F.softplus(std) + min_std, upscale=5) self.dist = get_tanh_normal case 'normal_trunc': def get_trunc_normal(x, min_std=0.1): mean, std = x.chunk(2, dim=-1) - std = 2 * torch.sigmoid(std / 2) + min_std - return TruncatedNormal(torch.tanh(mean), std, a=-1 + 1e-6, b=1 - 1e-6) + return TruncatedNormal(mean, 2*torch.sigmoid(std) + min_std, upscale=2) self.dist = get_trunc_normal case 'binary': self.dist = lambda x: td.Bernoulli(logits=x) @@ -720,7 +717,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation losses_ac = defaultdict( lambda: torch.zeros(1).to(self.device)) - zs, action_dists, next_zs, rewards, discount_factors, _, actions = self.imagine_trajectory( + zs, _, next_zs, rewards, discount_factors, _, actions = self.imagine_trajectory( initial_states) rewards = self.world_model.reward_normalizer(rewards) @@ -728,16 +725,16 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation discount_factors = self.critic.gamma * torch.ones_like(rewards) # Ignore all factors after first is_finished state - discount_factors = torch.cumprod(discount_factors, dim=0).detach() + discount_factors = torch.cumprod(discount_factors, dim=0) # Discounted factors should be shifted as they predict whether next state is terminal # First discount factor on contrary is always 1 as it cannot lead to trajectory finish - discount_factors = torch.cat([torch.ones_like(discount_factors[:1, :]), discount_factors[:-1, :]], dim=0) + discount_factors = torch.cat([torch.ones_like(discount_factors[:1, :]), discount_factors[:-1, :]], dim=0).detach() # vs = self.critic.lambda_return(zs[1:], rewards[:-1], discount_factors[:-1]) vs = rewards[:-1] + self.critic.gamma * self.critic.target_critic(zs[1:]).mode # vs = torch.cat([vs, rewards[-1].unsqueeze(0)], dim=0) - predicted_vs_dist = self.critic.estimate_value(zs[:-1].detach()) + predicted_vs_dist = 
self.critic.estimate_value(zs[:-1]) losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs.detach()).unsqueeze(-1) * discount_factors[:-1]).mean() # predicted_vs_dist = self.critic.estimate_value(zs[:-2].detach()) @@ -749,32 +746,33 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # last action should be ignored as it is not used to predict next state, thus no feedback # first value should be ignored as it is comes from replay buffer - baseline = self.critic.target_critic(next_zs[1:-2]).mode - advantage = (vs[1:-1] - baseline).detach() - losses_ac['loss_actor_reinforce'] += -(self.rho * (torch.stack([a.log_prob(actions[idx].detach()) for idx, a in enumerate(action_dists[1:-2])]) * discount_factors[1:-2].squeeze()) * advantage.squeeze()).mean() - losses_ac['loss_actor_dynamics_backprop'] = -((1 - self.rho) * (vs[:-1]*discount_factors[:-2])).mean() + action_dists = self.actor(zs[:-2].detach()) + # baseline = self.critic.target_critic(next_zs[1:-2]).mode + # advantage = (vs[1:-1] - baseline).detach() + losses_ac['loss_actor_reinforce'] += 0 #-(self.rho * (torch.stack([a.log_prob(actions[idx].detach()) for idx, a in enumerate(action_dists[1:])]) * discount_factors[1:-2].squeeze()) * advantage.squeeze()).mean() + losses_ac['loss_actor_dynamics_backprop'] = -((1 - self.rho) * (vs[1:]*discount_factors[:-2])).mean() def calculate_entropy(dist): # x_t = dist.base_dist.base_dist.rsample() # return -(dist.base_dist.base_dist.log_prob(x_t) - torch.log(1 - torch.tanh(x_t).pow(2) + 1e-6)).sum(1) return -dist.log_prob(dist.rsample((128,))).mean(0) - losses_ac['loss_actor_entropy'] += -(self.eta * - torch.stack([calculate_entropy(a) for a in action_dists[1:-2]]) * discount_factors[1:-2].squeeze()).mean() + losses_ac['loss_actor_entropy'] += -(self.eta * calculate_entropy(action_dists)).mean() + # torch.stack([calculate_entropy(a) for a in action_dists[1:]]) * discount_factors[1:-2].squeeze()).mean() # losses_ac['loss_actor_entropy'] += (self.eta * # torch.stack([a.entropy() for a in action_dists[1:-1]]) * discount_factors[1:-1].squeeze()).mean() losses_ac['loss_actor'] = losses_ac['loss_actor_reinforce'] + losses_ac['loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] # mean and std are estimated statistically as tanh transformation is used - act_avg = torch.stack([a.sample((128,)).mean(dim=0) for a in action_dists[:-1]]) + act_avg = torch.stack([a.sample((128,)).mean(dim=0) for a in action_dists[1:]]) # act_avg = torch.stack([a.mean for a in action_dists[:-1]]) metrics['actor/avg_val'] = act_avg.mean() - metrics['actor/mode_val'] = torch.stack([torch.mode(a.sample((128,)), 0)[0] for a in action_dists[:-1]]).mean() + metrics['actor/mode_val'] = torch.stack([a.mode for a in action_dists[1:]]).mean() # metrics['actor/avg_sd'] = torch.stack([a.stddev for a in action_dists[:-1]]).mean() - metrics['actor/avg_sd'] = (((torch.stack([a.sample((128,)) for a in action_dists[:-1]], dim=1) - act_avg)**2).mean(0).sqrt()).mean() - metrics['actor/min_val'] = torch.stack([a.sample((128,)) for a in action_dists[:-1]]).min() - metrics['actor/max_val'] = torch.stack([a.sample((128,)) for a in action_dists[:-1]]).max() + metrics['actor/avg_sd'] = (((torch.stack([a.sample((128,)) for a in action_dists[1:]], dim=1) - act_avg)**2).mean(0).sqrt()).mean() + metrics['actor/min_val'] = torch.stack([a.sample((128,)) for a in action_dists[1:]]).min() + metrics['actor/max_val'] = torch.stack([a.sample((128,)) for a in action_dists[1:]]).max() 
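        # The entropy term and the actor statistics above share one Monte-Carlo estimator:
        # the tanh/truncated-Normal policies used here have no closed-form moments or
        # entropy, so both are approximated from repeated samples. A minimal sketch of
        # the idea (K and dist are illustrative placeholders, not names from this code):
        #   samples = dist.rsample((K,))                     # K x batch x action_dim
        #   entropy_est = -dist.log_prob(samples).mean(0)    # Monte-Carlo entropy
        #   mean_est, std_est = samples.mean(0), samples.std(0)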
self.actor_optimizer.zero_grad(set_to_none=True) self.critic_optimizer.zero_grad(set_to_none=True) diff --git a/rl_sandbox/utils/fc_nn.py b/rl_sandbox/utils/fc_nn.py index e704556..8473623 100644 --- a/rl_sandbox/utils/fc_nn.py +++ b/rl_sandbox/utils/fc_nn.py @@ -7,10 +7,11 @@ def fc_nn_generator(input_num: int, num_layers: int, intermediate_activation: t.Type[nn.Module] = nn.ReLU, final_activation: nn.Module = nn.Identity()): + assert num_layers >= 3 layers = [] layers.append(nn.Linear(input_num, hidden_size)) layers.append(nn.ReLU(inplace=True)) - for _ in range(num_layers): + for _ in range(num_layers - 2): layers.append(nn.Linear(hidden_size, hidden_size)) layers.append(intermediate_activation(inplace=True)) layers.append(nn.Linear(hidden_size, output_num)) From 9a43e7fe26188fd68a3797265c27cb5653b1ff32 Mon Sep 17 00:00:00 2001 From: Midren Date: Fri, 27 Jan 2023 14:21:04 +0000 Subject: [PATCH 032/106] It is working! (confirmed for cheetah and walker) --- rl_sandbox/agents/dreamer_v2.py | 202 ++++++++++++++---------- rl_sandbox/config/agent/dreamer_v2.yaml | 8 +- rl_sandbox/config/config.yaml | 8 +- rl_sandbox/config/env/dm_cheetah.yaml | 8 + rl_sandbox/config/env/dm_quadruped.yaml | 8 + rl_sandbox/config/env/dm_walker.yaml | 8 + rl_sandbox/utils/dists.py | 136 ++++++++++++++++ 7 files changed, 288 insertions(+), 90 deletions(-) create mode 100644 rl_sandbox/config/env/dm_cheetah.yaml create mode 100644 rl_sandbox/config/env/dm_quadruped.yaml create mode 100644 rl_sandbox/config/env/dm_walker.yaml create mode 100644 rl_sandbox/utils/dists.py diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 574069a..7fa33ad 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -1,4 +1,3 @@ -import itertools import typing as t from collections import defaultdict @@ -14,10 +13,7 @@ from rl_sandbox.utils.replay_buffer import (Action, Actions, Observation, Observations, Rewards, TerminationFlags) - -from torchviz import make_dot -from torchrl.modules import TanhNormal, TruncatedNormal - +from rl_sandbox.utils.dists import TruncatedNormal class View(nn.Module): @@ -65,26 +61,28 @@ def forward(self, x): else: return td.TransformedDistribution(dist, self.trans) - class DistLayer(nn.Module): def __init__(self, type: str): super().__init__() match type: case 'mse': - self.dist = lambda x: td.Normal(x, 1.0) + self.dist = lambda x: td.Normal(x.float(), 1.0) case 'normal': self.dist = NormalWithOffset(min_std=0.1) case 'onehot': - self.dist = lambda x: td.OneHotCategoricalStraightThrough(logits=x) + # Forcing float32 on AMP + self.dist = lambda x: td.OneHotCategoricalStraightThrough(logits=x.float()) case 'normal_tanh': def get_tanh_normal(x, min_std=0.1): mean, std = x.chunk(2, dim=-1) - return TanhNormal(mean, F.softplus(std) + min_std, upscale=5) + init_std = np.log(np.exp(5) - 1) + raise NotImplementedError() + # return TanhNormal(torch.clamp(mean, -9.0, 9.0).float(), (F.softplus(std + init_std) + min_std).float(), upscale=5) self.dist = get_tanh_normal case 'normal_trunc': def get_trunc_normal(x, min_std=0.1): mean, std = x.chunk(2, dim=-1) - return TruncatedNormal(mean, 2*torch.sigmoid(std) + min_std, upscale=2) + return TruncatedNormal(loc=torch.tanh(mean).float(), scale=(2*torch.sigmoid(std/2) + min_std).float(), a=-1, b=1) self.dist = get_trunc_normal case 'binary': self.dist = lambda x: td.Bernoulli(logits=x) @@ -324,7 +322,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, num_layers=4, 
intermediate_activation=nn.ELU, final_activation=DistLayer('binary')) - self.reward_normalizer = Normalizer(momentum=0.99, scale=1.0, eps=1e-8) + self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor]): determ_state, next_repr_logits = self.recurrent_model.predict_next( @@ -332,9 +330,9 @@ def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor next_repr = Dist(next_repr_logits).rsample().reshape( -1, self.latent_dim * self.latent_classes) - reward = self.reward_predictor( - torch.concat([determ_state.squeeze(0), next_repr], dim=1)).mode - discount_factors = self.discount_predictor(torch.concat([determ_state.squeeze(0), next_repr], dim=1)).sample() + inp = torch.concat([determ_state.squeeze(0), next_repr], dim=-1) + reward = self.reward_predictor(inp).mode + discount_factors = self.discount_predictor(inp).sample() return determ_state, next_repr, reward, discount_factors def get_latent(self, obs: torch.Tensor, action, state): @@ -350,21 +348,25 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, embed = self.encoder(obs) embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) - obs_c = obs.reshape(-1, self.cluster_size, 3, h, w) a_c = a.reshape(-1, self.cluster_size, self.actions_num) r_c = r.reshape(-1, self.cluster_size, 1) d_c = discount.reshape(-1, self.cluster_size, 1) h_prev = None losses = defaultdict(lambda: torch.zeros(1).to(next(self.parameters()).device)) + metrics = defaultdict(lambda: torch.zeros(1).to(next(self.parameters()).device)) - def KL(dist1, dist2): + def KL(dist1, dist2, free_nat = True): KL_ = torch.distributions.kl_divergence one = self.kl_free_nats * torch.ones(1, device=next(self.parameters()).device) # TODO: kl_free_avg is used always - kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)).mean(), one) - kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())).mean(), one) - return self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs) + if free_nat: + kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)).mean(), one) + kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())).mean(), one) + else: + kl_lhs = KL_(Dist(dist2.detach()), Dist(dist1)).mean() + kl_rhs = KL_(Dist(dist2), Dist(dist1.detach())).mean() + return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) latent_vars = [] determ_vars = [] @@ -376,8 +378,7 @@ def KL(dist1, dist2): for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 - x_t, embed_t, a_t, r_t, f_t = obs_c[:, t], embed_c[:, t].unsqueeze( - 0), a_c[:, t].unsqueeze(0), r_c[:, t], d_c[:, t] + embed_t, a_t = embed_c[:, t].unsqueeze(0), a_c[:, t].unsqueeze(0) determ_t, prior_stoch_logits, posterior_stoch_logits = self.recurrent_model.forward( h_prev, embed_t, a_t) @@ -396,15 +397,19 @@ def KL(dist1, dist2): r_pred = self.reward_predictor(inp) f_pred = self.discount_predictor(inp) x_r = self.image_predictor(torch.flatten(inp, 0, 1)) + prior_logits = torch.flatten(torch.stack(prior_logits, dim=1), 0, 1) + posterior_logits = torch.flatten(torch.stack(posterior_logits, dim=1), 0, 1) + + losses['loss_reconstruction'] = -x_r.log_prob(obs).mean() + losses['loss_reward_pred'] = -r_pred.log_prob(r_c).mean() + losses['loss_discount_pred'] = -f_pred.log_prob(d_c).mean() + losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) - losses['loss_reconstruction'] += -x_r.log_prob(obs).mean() - losses['loss_reward_pred'] += 
-r_pred.log_prob(r_c).mean() - losses['loss_discount_pred'] += -f_pred.log_prob(d_c).mean() - # NOTE: entropy can be added as metric - losses['loss_kl_reg'] += KL(torch.flatten(torch.stack(prior_logits, dim=1), 0, 1), - torch.flatten(torch.stack(posterior_logits, dim=1), 0, 1)) + metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() + metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() + metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() - return losses, torch.stack(latent_vars, dim=1).reshape(-1, self.latent_dim * self.latent_classes).detach() + return losses, inp.flatten(0, 1).detach(), metrics class ImaginativeCritic(nn.Module): @@ -430,24 +435,28 @@ def __init__(self, discount_factor: float, update_interval: int, 4, intermediate_activation=nn.ELU, final_activation=DistLayer('mse')) + self.target_critic.requires_grad_(False) def update_target(self): if self._update_num == 0: - for target_param, local_param in zip(self.target_critic.parameters(), - self.critic.parameters()): - mix = self.critic_soft_update_fraction - target_param.data.copy_(mix * local_param.data + - (1 - mix) * target_param.data) + self.target_critic.load_state_dict(self.critic.state_dict()) + # for target_param, local_param in zip(self.target_critic.parameters(), + # self.critic.parameters()): + # mix = self.critic_soft_update_fraction + # target_param.data.copy_(mix * local_param.data + + # (1 - mix) * target_param.data) self._update_num = (self._update_num + 1) % self.critic_update_interval def estimate_value(self, z) -> td.Distribution: return self.critic(z) def _lambda_return(self, vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor): - v_lambdas = [rs[-1] + self.gamma*vs[-1]] - for i in range(vs.shape[0] - 2, -1, -1): - v_lambda = rs[i] + ds[i] * self.gamma * ( - (1 - self.lambda_) * vs[i] + + # Formula is actually slightly different than in paper + # https://github.com/danijar/dreamerv2/issues/25 + v_lambdas = [vs[-1]] + for i in range(rs.shape[0] - 2, -1, -1): + v_lambda = rs[i] + ds[i] * ( + (1 - self.lambda_) * vs[i+1] + self.lambda_ * v_lambdas[-1]) v_lambdas.append(v_lambda) @@ -489,25 +498,25 @@ def __init__( self.cluster_size = batch_cluster_size self.actions_num = actions_num self.rho = actor_reinforce_fraction - if actor_reinforce_fraction != 0: - raise NotImplementedError("Reinforce part is not implemented") + # if actor_reinforce_fraction != 0: + # raise NotImplementedError("Reinforce part is not implemented") self.eta = actor_entropy_scale self.world_model = WorldModel(batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_loss_free_nats).to(device_type) - self.actor = fc_nn_generator(latent_dim * latent_classes, + self.actor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, actions_num * 2 if action_type == 'continuous' else actions_num, 400, 4, intermediate_activation=nn.ELU, - final_activation=DistLayer('normal_tanh' if action_type == 'continuous' else 'onehot')).to(device_type) + final_activation=DistLayer('normal_trunc' if action_type == 'continuous' else 'onehot')).to(device_type) self.critic = ImaginativeCritic(discount_factor, critic_update_interval, critic_soft_update_fraction, critic_value_target_lambda, - latent_dim * latent_classes).to(device_type) + rssm_dim + latent_dim * latent_classes).to(device_type) self.scaler = torch.cuda.amp.GradScaler() self.world_model_optimizer = torch.optim.AdamW(self.world_model.parameters(), @@ -530,15 +539,18 @@ def imagine_trajectory( torch.Tensor, 
torch.Tensor, torch.Tensor]: if horizon is None: horizon = self.imagination_horizon - world_state = None + world_state = torch.zeros(1, z_0.shape[0], 200, device=z_0.device) zs, actions_dists, actions, next_zs, rewards, ts, determs = [], [], [], [], [], [], [] z = z_0 for i in range(horizon): + # FIXME: if somebody sees it, you have no credibility as programmer + if z.shape[1] == 1224: + z, world_state = z[:, :1024], z[:, 1024:].unsqueeze(0) if precomp_actions is not None: a_dist = None a = precomp_actions[i].unsqueeze(0) else: - a_dist = self.actor(z.detach()) + a_dist = self.actor(torch.cat([world_state.squeeze(), z], dim=-1).detach()) a = a_dist.rsample() world_state, next_z, reward, discount = self.world_model.predict_next( z, a, world_state) @@ -583,8 +595,10 @@ def get_action(self, obs: Observation) -> Action: self._state = (determ, latent_repr_dist.rsample().reshape(-1, 32 * 32).unsqueeze(0)) - actor_dist = self.actor(self._state[1]) - self._last_action = actor_dist.rsample() + actor_dist = self.actor(torch.cat(self._state, dim=-1)) + # FIXME: expl_noise magic number + self._last_action = actor_dist.rsample() + torch.randn_like(self._last_action) * 0.3 + self._last_action = torch.clamp(self._last_action, -1, 1) if False: self._action_probs += actor_dist.base_dist.probs.squeeze() @@ -606,11 +620,11 @@ def _generate_video(self, obs: Observation, actions: list[Action]): else: actions = self.from_np(actions) z_0 = Dist(self.world_model.get_latent(obs, actions[0].unsqueeze(0).unsqueeze(0), None)[1]).rsample().reshape(-1, 32 * 32).unsqueeze(0) - zs, _, _, _, _, determs, _ = self.imagine_trajectory(z_0.squeeze(0), actions[1:], horizon=self.imagination_horizon - 1) + zs, _, _, rews, _, determs, _ = self.imagine_trajectory(z_0.squeeze(0), actions[1:], horizon=self.imagination_horizon - 1) # video_r = self.world_model.image_predictor(torch.concat([determs, zs], dim=2)).rsample().cpu().detach().numpy() - video_r = self.world_model.image_predictor(torch.concat([torch.concat([torch.zeros_like(determs[0]).unsqueeze(0), determs]), torch.concat([z_0, zs])], dim=2)).mode.cpu().detach().numpy() + video_r = self.world_model.image_predictor(torch.concat([torch.concat([torch.zeros_like(determs[0]).unsqueeze(0), determs]), torch.concat([z_0, zs])], dim=-1)).mode.cpu().detach().numpy() video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) - return video_r + return video_r, rews.sum() def _generate_video_with_update(self, obs: list[Observation], init_action: list[Action]): obs = torch.from_numpy(obs.copy()).to(self.device) @@ -623,27 +637,30 @@ def _generate_video_with_update(self, obs: list[Observation], init_action: list[ action = self.from_np(init_action) state = None video = [] + rews = [] for o, a in zip(obs, action): determ, stoch_logits = self.world_model.get_latent(o.unsqueeze(0), a.unsqueeze(0).unsqueeze(0), state) z_0 = Dist(stoch_logits).rsample().reshape(-1, 32 * 32).unsqueeze(0) state = (determ, z_0) - video_r = self.world_model.image_predictor(torch.concat([determ, z_0], dim=-1)).mode.cpu().detach().numpy() + inp = torch.concat([determ, z_0], dim=-1) + video_r = self.world_model.image_predictor(inp).mode.cpu().detach().numpy() + rews.append(self.world_model.reward_predictor(inp).mode.item()) video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) video.append(video_r) - return np.concatenate(video) + return np.concatenate(video), sum(rews) def viz_log(self, rollout, logger, epoch_num): init_indeces = np.random.choice(len(rollout.states) - self.imagination_horizon, 3) - videos_r = np.concatenate([ 
- self._generate_video(obs_0.copy(), a_0) for obs_0, a_0 in zip( - rollout.next_states[init_indeces], [rollout.actions[idx:idx+ self.imagination_horizon] for idx in init_indeces]) - ], axis=3) + real_rewards = [rollout.rewards[idx:idx+ self.imagination_horizon].sum() for idx in init_indeces] + videos, imagined_rewards = zip(*[self._generate_video(obs_0.copy(), a_0) for obs_0, a_0 in zip( + rollout.next_states[init_indeces], [rollout.actions[idx:idx+ self.imagination_horizon] for idx in init_indeces])]) + videos_r = np.concatenate(videos, axis=3) - videos_r_update = np.concatenate([ - self._generate_video_with_update(obs_0.copy(), a_0) for obs_0, a_0 in zip( + videos_update, imagined_update_rewards = zip(*[self._generate_video_with_update(obs_0.copy(), a_0) for obs_0, a_0 in zip( [rollout.next_states[idx:idx+ self.imagination_horizon] for idx in init_indeces], [rollout.actions[idx:idx+ self.imagination_horizon] for idx in init_indeces]) - ], axis=3) + ]) + videos_r_update = np.concatenate(videos_update, axis=3) videos = np.concatenate([ rollout.next_states[init_idx:init_idx + self.imagination_horizon].transpose( @@ -668,6 +685,13 @@ def viz_log(self, rollout, logger, epoch_num): logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=1), epoch_num, dataformats='HW') logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) + rewards_err = torch.Tensor([torch.abs(imagined_rewards[i] - real_rewards[i]) for i in range(len(imagined_rewards))]).mean() + rewards_update_err = np.mean([np.abs(imagined_update_rewards[i] - real_rewards[i]) for i in range(len(imagined_rewards))]) + logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) + logger.add_scalar('val/img_update_reward_err', rewards_update_err.item(), epoch_num) + + logger.add_scalar(f'val/reward', real_rewards[0], epoch_num) + def from_np(self, arr: np.ndarray): arr = torch.from_numpy(arr) if isinstance(arr, np.ndarray) else arr return arr.to(self.device, non_blocking=True) @@ -682,10 +706,13 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation r = self.from_np(r) next_obs = self.preprocess_obs(self.from_np(next_obs)) discount_factors = (1 - self.from_np(is_finished).type(torch.float32)) + number_of_zero_discounts = (1 - discount_factors).sum() + if number_of_zero_discounts > 0: + pass # take some latent embeddings as initial with torch.cuda.amp.autocast(enabled=True): - losses, discovered_latents = self.world_model.calculate_loss( + losses, discovered_latents, wm_metrics = self.world_model.calculate_loss( next_obs, a, r, discount_factors) # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device @@ -707,6 +734,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation metrics = defaultdict( lambda: torch.zeros(1).to(self.device)) + metrics |= wm_metrics with torch.cuda.amp.autocast(enabled=True): # idx = torch.randperm(discovered_latents.size(0)) @@ -717,47 +745,53 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation losses_ac = defaultdict( lambda: torch.zeros(1).to(self.device)) - zs, _, next_zs, rewards, discount_factors, _, actions = self.imagine_trajectory( + zs, _, next_zs, rewards, discount_factors, determs, actions = self.imagine_trajectory( initial_states) + zs = torch.cat([determs.squeeze(), zs], dim=-1) rewards = self.world_model.reward_normalizer(rewards) # Discount prediction is disabled for dmc vision in Dreamer + # as trajectory will not abruptly stop 
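        # The lambda-return computed a few lines below follows the recursion in
        # _lambda_return:
        #   V_t = r_t + d_t * ((1 - lambda_) * v(s_{t+1}) + lambda_ * V_{t+1}),  V_H = v(s_H),
        # where d_t is the per-step discount (here the constant gamma, shifted so the
        # first factor is 1) and lambda_ mixes one-step and longer-horizon targets.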
discount_factors = self.critic.gamma * torch.ones_like(rewards) - # Ignore all factors after first is_finished state - discount_factors = torch.cumprod(discount_factors, dim=0) - # Discounted factors should be shifted as they predict whether next state is terminal # First discount factor on contrary is always 1 as it cannot lead to trajectory finish - discount_factors = torch.cat([torch.ones_like(discount_factors[:1, :]), discount_factors[:-1, :]], dim=0).detach() + discount_factors = torch.cat([torch.ones_like(discount_factors[:1]), discount_factors[:-2]], dim=0).detach() + + vs = self.critic.lambda_return(zs, rewards[:-1], discount_factors) - # vs = self.critic.lambda_return(zs[1:], rewards[:-1], discount_factors[:-1]) - vs = rewards[:-1] + self.critic.gamma * self.critic.target_critic(zs[1:]).mode - # vs = torch.cat([vs, rewards[-1].unsqueeze(0)], dim=0) - predicted_vs_dist = self.critic.estimate_value(zs[:-1]) - losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs.detach()).unsqueeze(-1) * discount_factors[:-1]).mean() + # Ignore all factors after first is_finished state + discount_factors = torch.cumprod(discount_factors, dim=0) + + # vs = rewards[:-1] + self.critic.gamma * self.critic.target_critic(zs[1:]) * discount_factors[1:] + predicted_vs_dist = self.critic.estimate_value(zs[:-1].detach()) + # losses_ac['loss_critic'] = F.mse_loss(predicted_vs_dist, vs[:-1].detach()) + losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs.detach())).mean() # predicted_vs_dist = self.critic.estimate_value(zs[:-2].detach()) # losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs[:-1].detach()).unsqueeze(-1) * discount_factors[:-2]).mean() metrics['critic/avg_target_value'] = self.critic.target_critic(zs[1:]).mode.mean() + # metrics['critic/avg_target_value'] = self.critic.target_critic(zs[1:]).mean() metrics['critic/avg_lambda_value'] = vs.mean() metrics['critic/avg_predicted_value'] = predicted_vs_dist.mode.mean() + # metrics['critic/avg_predicted_value'] = predicted_vs_dist.mean() # last action should be ignored as it is not used to predict next state, thus no feedback # first value should be ignored as it is comes from replay buffer - action_dists = self.actor(zs[:-2].detach()) - # baseline = self.critic.target_critic(next_zs[1:-2]).mode - # advantage = (vs[1:-1] - baseline).detach() - losses_ac['loss_actor_reinforce'] += 0 #-(self.rho * (torch.stack([a.log_prob(actions[idx].detach()) for idx, a in enumerate(action_dists[1:])]) * discount_factors[1:-2].squeeze()) * advantage.squeeze()).mean() - losses_ac['loss_actor_dynamics_backprop'] = -((1 - self.rho) * (vs[1:]*discount_factors[:-2])).mean() + action_dists = self.actor(zs[1:-1].detach()) + baseline = self.critic.target_critic(zs[1:-1]).mode + advantage = (vs[1:] - baseline).detach() + losses_ac['loss_actor_reinforce'] += 0# -(self.rho * action_dists.base_dist.log_prob(actions[1:-1].detach()).unsqueeze(2) * discount_factors[:-2] * advantage).mean() + losses_ac['loss_actor_dynamics_backprop'] = -((1 - self.rho) * (vs[1:]*discount_factors[:-1])).mean() def calculate_entropy(dist): - # x_t = dist.base_dist.base_dist.rsample() - # return -(dist.base_dist.base_dist.log_prob(x_t) - torch.log(1 - torch.tanh(x_t).pow(2) + 1e-6)).sum(1) - return -dist.log_prob(dist.rsample((128,))).mean(0) + # return -dist.base_dist.log_prob(dist.rsample((1024,))).mean(0).unsqueeze(2) + return dist.entropy().unsqueeze(2) + # return dist.base_dist.base_dist.entropy().unsqueeze(2) + + losses_ac['loss_actor_entropy'] += -(self.eta * 
calculate_entropy(action_dists)*discount_factors[:-1]).mean() - losses_ac['loss_actor_entropy'] += -(self.eta * calculate_entropy(action_dists)).mean() # torch.stack([calculate_entropy(a) for a in action_dists[1:]]) * discount_factors[1:-2].squeeze()).mean() # losses_ac['loss_actor_entropy'] += (self.eta * # torch.stack([a.entropy() for a in action_dists[1:-1]]) * discount_factors[1:-1].squeeze()).mean() @@ -765,14 +799,16 @@ def calculate_entropy(dist): losses_ac['loss_actor'] = losses_ac['loss_actor_reinforce'] + losses_ac['loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] # mean and std are estimated statistically as tanh transformation is used - act_avg = torch.stack([a.sample((128,)).mean(dim=0) for a in action_dists[1:]]) + sample = action_dists.rsample((128,)) + act_avg = sample.mean(0) # act_avg = torch.stack([a.mean for a in action_dists[:-1]]) metrics['actor/avg_val'] = act_avg.mean() - metrics['actor/mode_val'] = torch.stack([a.mode for a in action_dists[1:]]).mean() + # metrics['actor/mode_val'] = action_dists.mode.mean() + metrics['actor/mean_val'] = action_dists.mean.mean() # metrics['actor/avg_sd'] = torch.stack([a.stddev for a in action_dists[:-1]]).mean() - metrics['actor/avg_sd'] = (((torch.stack([a.sample((128,)) for a in action_dists[1:]], dim=1) - act_avg)**2).mean(0).sqrt()).mean() - metrics['actor/min_val'] = torch.stack([a.sample((128,)) for a in action_dists[1:]]).min() - metrics['actor/max_val'] = torch.stack([a.sample((128,)) for a in action_dists[1:]]).max() + metrics['actor/avg_sd'] = (((sample - act_avg)**2).mean(0).sqrt()).mean() + metrics['actor/min_val'] = sample.min() + metrics['actor/max_val'] = sample.max() self.actor_optimizer.zero_grad(set_to_none=True) self.critic_optimizer.zero_grad(set_to_none=True) diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index ee50a82..955edb4 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -1,6 +1,6 @@ _target_: rl_sandbox.agents.DreamerV2 # World model parameters -batch_cluster_size: 32 +batch_cluster_size: 50 latent_dim: 32 latent_classes: 32 rssm_dim: 200 @@ -10,13 +10,13 @@ kl_loss_free_nats: 1.0 world_model_lr: 3e-4 # ActorCritic parameters -discount_factor: 0.99 -imagination_horizon: 15 +discount_factor: 0.999 +imagination_horizon: 16 actor_lr: 8e-5 # mixing of reinforce and maximizing value func # for dm_control it is zero in Dreamer (Atari 1) -actor_reinforce_fraction: 0 +actor_reinforce_fraction: 0.0 actor_entropy_scale: 1e-4 critic_lr: 8e-5 diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 9ad962c..b9b11c8 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,17 +1,19 @@ defaults: - agent/dreamer_v2 - #- env/dm_quadruped + #- env/dm_walker - env/dm_cartpole + #- env/dm_quadruped + #- env/dm_cheetah - _self_ seed: 42 device_type: cuda training: - steps: 5e5 + steps: 1e6 prefill: 1000 pretrain: 100 - batch_size: 1024 + batch_size: 2500 gradient_steps_per_step: 5 save_checkpoint_every: 1e5 val_logs_every: 2.5e3 diff --git a/rl_sandbox/config/env/dm_cheetah.yaml b/rl_sandbox/config/env/dm_cheetah.yaml new file mode 100644 index 0000000..a6d1490 --- /dev/null +++ b/rl_sandbox/config/env/dm_cheetah.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: cheetah +task_name: run +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_quadruped.yaml 
b/rl_sandbox/config/env/dm_quadruped.yaml new file mode 100644 index 0000000..aa5e541 --- /dev/null +++ b/rl_sandbox/config/env/dm_quadruped.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: quadruped +task_name: walk +run_on_pixels: true +obs_res: [64, 64] +camera_id: 2 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_walker.yaml b/rl_sandbox/config/env/dm_walker.yaml new file mode 100644 index 0000000..68d8cf5 --- /dev/null +++ b/rl_sandbox/config/env/dm_walker.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: walker +task_name: walk +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/utils/dists.py b/rl_sandbox/utils/dists.py new file mode 100644 index 0000000..57214fa --- /dev/null +++ b/rl_sandbox/utils/dists.py @@ -0,0 +1,136 @@ +# Taken from https://raw.githubusercontent.com/toshas/torch_truncnorm/main/TruncatedNormal.py +import math +from numbers import Number + +import torch +from torch.distributions import Distribution, constraints +from torch.distributions.utils import broadcast_all + +CONST_SQRT_2 = math.sqrt(2) +CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi) +CONST_INV_SQRT_2 = 1 / math.sqrt(2) +CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI) +CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e) + + +class TruncatedStandardNormal(Distribution): + """ + Truncated Standard Normal distribution + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + """ + + arg_constraints = { + 'a': constraints.real, + 'b': constraints.real, + } + has_rsample = True + + def __init__(self, a, b, validate_args=None): + self.a, self.b = broadcast_all(a, b) + if isinstance(a, Number) and isinstance(b, Number): + batch_shape = torch.Size() + else: + batch_shape = self.a.size() + super(TruncatedStandardNormal, self).__init__(batch_shape, validate_args=validate_args) + if self.a.dtype != self.b.dtype: + raise ValueError('Truncation bounds types are different') + if any((self.a >= self.b).view(-1,).tolist()): + raise ValueError('Incorrect truncation range') + eps = torch.finfo(self.a.dtype).eps + self._dtype_min_gt_0 = eps + self._dtype_max_lt_1 = 1 - eps + self._little_phi_a = self._little_phi(self.a) + self._little_phi_b = self._little_phi(self.b) + self._big_phi_a = self._big_phi(self.a) + self._big_phi_b = self._big_phi(self.b) + self._Z = (self._big_phi_b - self._big_phi_a).clamp_min(eps) + self._log_Z = self._Z.log() + little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan) + little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan) + self._lpbb_m_lpaa_d_Z = (self._little_phi_b * little_phi_coeff_b - self._little_phi_a * little_phi_coeff_a) / self._Z + self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z + self._variance = 1 - self._lpbb_m_lpaa_d_Z - ((self._little_phi_b - self._little_phi_a) / self._Z) ** 2 + self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z + + @constraints.dependent_property + def support(self): + return constraints.interval(self.a, self.b) + + @property + def mean(self): + return self._mean + + @property + def variance(self): + return self._variance + + def entropy(self): + return self._entropy + + @property + def auc(self): + return self._Z + + @staticmethod + def _little_phi(x): + return (-(x ** 2) * 0.5).exp() * CONST_INV_SQRT_2PI + + @staticmethod + def _big_phi(x): + return 0.5 * (1 + (x * CONST_INV_SQRT_2).erf()) + + @staticmethod + def _inv_big_phi(x): + 
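        # Inverse of the standard normal CDF implemented in _big_phi above:
        # solving p = 0.5 * (1 + erf(x / sqrt(2))) for x gives
        # x = sqrt(2) * erfinv(2p - 1), which is what the return below computes.
        # icdf() feeds it uniform draws, so rsample() is inverse-transform sampling
        # restricted to the truncation interval [a, b].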
return CONST_SQRT_2 * (2 * x - 1).erfinv() + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1) + + def icdf(self, value): + return self._inv_big_phi(self._big_phi_a + value * self._Z) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value ** 2) * 0.5 + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + p = torch.empty(shape, device=self.a.device).uniform_(self._dtype_min_gt_0, self._dtype_max_lt_1) + return self.icdf(p) + + +class TruncatedNormal(TruncatedStandardNormal): + """ + Truncated Normal distribution + https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + """ + + has_rsample = True + + def __init__(self, loc, scale, a, b, validate_args=None): + self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b) + a = (a - self.loc) / self.scale + b = (b - self.loc) / self.scale + super(TruncatedNormal, self).__init__(a, b, validate_args=validate_args) + self._log_scale = self.scale.log() + self._mean = self._mean * self.scale + self.loc + self._variance = self._variance * self.scale ** 2 + self._entropy += self._log_scale + + def _to_std_rv(self, value): + return (value - self.loc) / self.scale + + def _from_std_rv(self, value): + return value * self.scale + self.loc + + def cdf(self, value): + return super(TruncatedNormal, self).cdf(self._to_std_rv(value)) + + def icdf(self, value): + return self._from_std_rv(super(TruncatedNormal, self).icdf(value)) + + def log_prob(self, value): + return super(TruncatedNormal, self).log_prob(self._to_std_rv(value)) - self._log_scale From b12adad6bb83d18ceb5dc5c8d378316336ee5d78 Mon Sep 17 00:00:00 2001 From: Midren Date: Fri, 27 Jan 2023 17:12:54 +0000 Subject: [PATCH 033/106] Changed income data to account for 1 frame, fixed rewards/discount_factor prediction --- rl_sandbox/agents/dreamer_v2.py | 39 +++++++------------------- rl_sandbox/train.py | 8 +++--- rl_sandbox/utils/replay_buffer.py | 12 +++++--- rl_sandbox/utils/rollout_generation.py | 4 ++- 4 files changed, 25 insertions(+), 38 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 7fa33ad..78eafd9 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -12,7 +12,7 @@ from rl_sandbox.utils.fc_nn import fc_nn_generator from rl_sandbox.utils.replay_buffer import (Action, Actions, Observation, Observations, Rewards, - TerminationFlags) + TerminationFlags, IsFirstFlags) from rl_sandbox.utils.dists import TruncatedNormal class View(nn.Module): @@ -278,16 +278,11 @@ def __init__(self, momentum=0.99, scale=1.0, eps=1e-8): def forward(self, x): self.update(x) - val = self.trans(x) - return val + return (x / (self.mag + self.eps))*self.scale def update(self, x): self.mag = self.momentum * self.mag.to(x.device) + (1 - self.momentum) * (x.abs().mean()).detach() - def trans(self, x): - x = x / (self.mag + self.eps) - return x*self.scale - class WorldModel(nn.Module): @@ -342,7 +337,7 @@ def get_latent(self, obs: torch.Tensor, action, state): return determ, latent_repr_logits def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, - discount: torch.Tensor): + discount: torch.Tensor, first: torch.Tensor): b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) @@ -351,6 +346,7 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: 
torch.Tensor, a_c = a.reshape(-1, self.cluster_size, self.actions_num) r_c = r.reshape(-1, self.cluster_size, 1) d_c = discount.reshape(-1, self.cluster_size, 1) + first_c = first.reshape(-1, self.cluster_size, 1) h_prev = None losses = defaultdict(lambda: torch.zeros(1).to(next(self.parameters()).device)) @@ -378,7 +374,8 @@ def KL(dist1, dist2, free_nat = True): for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 - embed_t, a_t = embed_c[:, t].unsqueeze(0), a_c[:, t].unsqueeze(0) + embed_t, a_t, first_t = embed_c[:, t].unsqueeze(0), a_c[:, t].unsqueeze(0), first_c[:, t].unsqueeze(0) + a_t = a_t * (1 - first_t) determ_t, prior_stoch_logits, posterior_stoch_logits = self.recurrent_model.forward( h_prev, embed_t, a_t) @@ -596,9 +593,6 @@ def get_action(self, obs: Observation) -> Action: 32 * 32).unsqueeze(0)) actor_dist = self.actor(torch.cat(self._state, dim=-1)) - # FIXME: expl_noise magic number - self._last_action = actor_dist.rsample() + torch.randn_like(self._last_action) * 0.3 - self._last_action = torch.clamp(self._last_action, -1, 1) if False: self._action_probs += actor_dist.base_dist.probs.squeeze() @@ -697,7 +691,7 @@ def from_np(self, arr: np.ndarray): return arr.to(self.device, non_blocking=True) def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observations, - is_finished: TerminationFlags): + is_finished: TerminationFlags, is_first: IsFirstFlags): obs = self.preprocess_obs(self.from_np(obs)) a = self.from_np(a).to(torch.int64) @@ -706,6 +700,8 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation r = self.from_np(r) next_obs = self.preprocess_obs(self.from_np(next_obs)) discount_factors = (1 - self.from_np(is_finished).type(torch.float32)) + first_flags = self.from_np(is_first).type(torch.float32) + number_of_zero_discounts = (1 - discount_factors).sum() if number_of_zero_discounts > 0: pass @@ -713,7 +709,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # take some latent embeddings as initial with torch.cuda.amp.autocast(enabled=True): losses, discovered_latents, wm_metrics = self.world_model.calculate_loss( - next_obs, a, r, discount_factors) + obs, a, r, discount_factors, first_flags) # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device # world_model_loss = torch.Tensor(1).to(self.device) @@ -763,19 +759,12 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # Ignore all factors after first is_finished state discount_factors = torch.cumprod(discount_factors, dim=0) - # vs = rewards[:-1] + self.critic.gamma * self.critic.target_critic(zs[1:]) * discount_factors[1:] predicted_vs_dist = self.critic.estimate_value(zs[:-1].detach()) - # losses_ac['loss_critic'] = F.mse_loss(predicted_vs_dist, vs[:-1].detach()) losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs.detach())).mean() - # predicted_vs_dist = self.critic.estimate_value(zs[:-2].detach()) - # losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs[:-1].detach()).unsqueeze(-1) * discount_factors[:-2]).mean() - metrics['critic/avg_target_value'] = self.critic.target_critic(zs[1:]).mode.mean() - # metrics['critic/avg_target_value'] = self.critic.target_critic(zs[1:]).mean() metrics['critic/avg_lambda_value'] = vs.mean() metrics['critic/avg_predicted_value'] = predicted_vs_dist.mode.mean() - # metrics['critic/avg_predicted_value'] = predicted_vs_dist.mean() # last action should be ignored as it is not used to predict next state, thus no feedback # first 
value should be ignored as it is comes from replay buffer @@ -786,26 +775,18 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation losses_ac['loss_actor_dynamics_backprop'] = -((1 - self.rho) * (vs[1:]*discount_factors[:-1])).mean() def calculate_entropy(dist): - # return -dist.base_dist.log_prob(dist.rsample((1024,))).mean(0).unsqueeze(2) return dist.entropy().unsqueeze(2) # return dist.base_dist.base_dist.entropy().unsqueeze(2) losses_ac['loss_actor_entropy'] += -(self.eta * calculate_entropy(action_dists)*discount_factors[:-1]).mean() - - # torch.stack([calculate_entropy(a) for a in action_dists[1:]]) * discount_factors[1:-2].squeeze()).mean() - # losses_ac['loss_actor_entropy'] += (self.eta * - # torch.stack([a.entropy() for a in action_dists[1:-1]]) * discount_factors[1:-1].squeeze()).mean() - losses_ac['loss_actor'] = losses_ac['loss_actor_reinforce'] + losses_ac['loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] # mean and std are estimated statistically as tanh transformation is used sample = action_dists.rsample((128,)) act_avg = sample.mean(0) - # act_avg = torch.stack([a.mean for a in action_dists[:-1]]) metrics['actor/avg_val'] = act_avg.mean() # metrics['actor/mode_val'] = action_dists.mode.mean() metrics['actor/mean_val'] = action_dists.mean.mean() - # metrics['actor/avg_sd'] = torch.stack([a.stddev for a in action_dists[:-1]]).mean() metrics['actor/avg_sd'] = (((sample - act_avg)**2).mean(0).sqrt()).mean() metrics['actor/min_val'] = sample.min() metrics['actor/max_val'] = sample.max() diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index cab1c5a..8490d1b 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -68,9 +68,9 @@ def main(cfg: DictConfig): with_stack=True) if cfg.debug.profiler else None for i in tqdm(range(int(cfg.training.pretrain)), desc='Pretraining'): - s, a, r, n, f = buff.sample(cfg.training.batch_size, + s, a, r, n, f, first = buff.sample(cfg.training.batch_size, cluster_size=cfg.agent.get('batch_cluster_size', 1)) - losses = agent.train(s, a, r, n, f) + losses = agent.train(s, a, r, n, f, first) for loss_name, loss in losses.items(): writer.add_scalar(f'pre_train/{loss_name}', loss, i) @@ -106,10 +106,10 @@ def main(cfg: DictConfig): if global_step % cfg.training.gradient_steps_per_step == 0: # NOTE: unintuitive that batch_size is now number of total # samples, but not amount of sequences for recurrent model - s, a, r, n, f = buff.sample(cfg.training.batch_size, + s, a, r, n, f, first = buff.sample(cfg.training.batch_size, cluster_size=cfg.agent.get('batch_cluster_size', 1)) - losses = agent.train(s, a, r, n, f) + losses = agent.train(s, a, r, n, f, first) if cfg.debug.profiler: prof.step() # NOTE: Do not forget to run test with every step to check for outliers diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index e0b432d..a483898 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -14,6 +14,7 @@ Actions = NDArray[Shape["*,*"], Int] Rewards = NDArray[Shape["*"], Float] TerminationFlags = NDArray[Shape["*"], Bool] +IsFirstFlags = TerminationFlags @dataclass @@ -31,7 +32,7 @@ def __len__(self): # TODO: make buffer concurrent-friendly class ReplayBuffer: - def __init__(self, max_len=2e5): + def __init__(self, max_len=2e6): self.rollouts: deque[Rollout] = deque() self.rollouts_len: deque[int] = deque() self.curr_rollout = None @@ -81,10 +82,10 @@ def sample( self, batch_size: int, cluster_size: int = 1 - ) -> tuple[States, 
Actions, Rewards, States, TerminationFlags]: + ) -> tuple[States, Actions, Rewards, States, TerminationFlags, IsFirstFlags]: seq_num = batch_size // cluster_size # NOTE: constant creation of numpy arrays from self.rollout_len seems terrible for me - s, a, r, n, t = [], [], [], [], [] + s, a, r, n, t, is_first = [], [], [], [], [], [] do_add_curr = self.curr_rollout is not None and len(self.curr_rollout.states) > cluster_size tot = self.total_num + (len(self.curr_rollout.states) if do_add_curr else 0) r_indeces = np.random.choice(len(self.rollouts) + int(do_add_curr), @@ -110,6 +111,9 @@ def sample( else: actions = rollout.actions[s_idx:s_idx + cluster_size] + is_first.append(np.zeros(cluster_size)) + if s_idx == 0: + is_first[-1][0] = 1 s.append(rollout.states[s_idx:s_idx + cluster_size]) a.append(actions) r.append(rollout.rewards[s_idx:s_idx + cluster_size]) @@ -121,4 +125,4 @@ def sample( n.append(rollout.states[s_idx+1:s_idx+1 + cluster_size - 1]) n.append(rollout.next_states) return (np.concatenate(s), np.concatenate(a), np.concatenate(r, dtype=np.float32), - np.concatenate(n), np.concatenate(t)) + np.concatenate(n), np.concatenate(t), np.concatenate(is_first)) diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index a5446f1..be9eb3f 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -54,6 +54,7 @@ def iter_rollout( state, _, terminated = unpack(env.reset()) agent.reset() + prev_action = np.zeros_like(agent.get_action(state)) while not terminated: action = agent.get_action(state) @@ -62,8 +63,9 @@ def iter_rollout( # FIXME: will break for non-DM obs = env.render() if collect_obs else None # if collect_obs and isinstance(env, dmEnv): - yield state, action, reward, new_state, terminated, obs + yield state, prev_action, reward, new_state, terminated, obs state = new_state + prev_action = action def collect_rollout(env: Env, From 42b448bfe7ce5376dd8b032916ee250a22af8d87 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 28 Jan 2023 12:48:41 +0000 Subject: [PATCH 034/106] Changed video logging (5 steps to update determ state, show error) --- rl_sandbox/agents/dreamer_v2.py | 60 ++++++++++--------------- rl_sandbox/config/agent/dreamer_v2.yaml | 2 +- 2 files changed, 24 insertions(+), 38 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 78eafd9..d08f09f 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -604,35 +604,18 @@ def get_action(self, obs: Observation) -> Action: else: return self._last_action.squeeze().detach().cpu().numpy() - def _generate_video(self, obs: Observation, actions: list[Action]): - obs = torch.from_numpy(obs.copy()).to(self.device) - obs = self.preprocess_obs(obs).unsqueeze(0) - - if False: - actions = F.one_hot(self.from_np(actions).to(torch.int64), - num_classes=self.actions_num).squeeze() - else: - actions = self.from_np(actions) - z_0 = Dist(self.world_model.get_latent(obs, actions[0].unsqueeze(0).unsqueeze(0), None)[1]).rsample().reshape(-1, 32 * 32).unsqueeze(0) - zs, _, _, rews, _, determs, _ = self.imagine_trajectory(z_0.squeeze(0), actions[1:], horizon=self.imagination_horizon - 1) - # video_r = self.world_model.image_predictor(torch.concat([determs, zs], dim=2)).rsample().cpu().detach().numpy() - video_r = self.world_model.image_predictor(torch.concat([torch.concat([torch.zeros_like(determs[0]).unsqueeze(0), determs]), torch.concat([z_0, zs])], dim=-1)).mode.cpu().detach().numpy() - video_r 
= ((video_r + 0.5) * 255.0).astype(np.uint8) - return video_r, rews.sum() - - def _generate_video_with_update(self, obs: list[Observation], init_action: list[Action]): + def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): obs = torch.from_numpy(obs.copy()).to(self.device) obs = self.preprocess_obs(obs) - - if False: - action = F.one_hot(self.from_np(init_action).to(torch.int64), - num_classes=self.actions_num).squeeze() - else: - action = self.from_np(init_action) + actions = self.from_np(actions) state = None video = [] rews = [] - for o, a in zip(obs, action): + + z_0 = Dist(self.world_model.get_latent(obs[0].unsqueeze(0), actions[0].unsqueeze(0).unsqueeze(0), None)[1]).rsample().reshape(-1, 32 * 32).unsqueeze(0) + for idx, (o, a) in enumerate(list(zip(obs, actions))): + if idx >= update_num: + break determ, stoch_logits = self.world_model.get_latent(o.unsqueeze(0), a.unsqueeze(0).unsqueeze(0), state) z_0 = Dist(stoch_logits).rsample().reshape(-1, 32 * 32).unsqueeze(0) state = (determ, z_0) @@ -641,27 +624,32 @@ def _generate_video_with_update(self, obs: list[Observation], init_action: list[ rews.append(self.world_model.reward_predictor(inp).mode.item()) video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) video.append(video_r) + + if update_num < len(obs): + zs, _, _, rews, _, determs, _ = self.imagine_trajectory(z_0.squeeze(0), actions[update_num+1:], horizon=self.imagination_horizon - 1 - update_num) + video_r = self.world_model.image_predictor(torch.concat([torch.concat([torch.zeros_like(determs[0]).unsqueeze(0), determs]), torch.concat([z_0, zs])], dim=-1)).mode.cpu().detach().numpy() + video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) + video.append(video_r) + return np.concatenate(video), sum(rews) def viz_log(self, rollout, logger, epoch_num): init_indeces = np.random.choice(len(rollout.states) - self.imagination_horizon, 3) + + videos = np.concatenate([ + rollout.next_states[init_idx:init_idx + self.imagination_horizon].transpose( + 0, 3, 1, 2) for init_idx in init_indeces + ], axis=3) + real_rewards = [rollout.rewards[idx:idx+ self.imagination_horizon].sum() for idx in init_indeces] - videos, imagined_rewards = zip(*[self._generate_video(obs_0.copy(), a_0) for obs_0, a_0 in zip( - rollout.next_states[init_indeces], [rollout.actions[idx:idx+ self.imagination_horizon] for idx in init_indeces])]) - videos_r = np.concatenate(videos, axis=3) - videos_update, imagined_update_rewards = zip(*[self._generate_video_with_update(obs_0.copy(), a_0) for obs_0, a_0 in zip( + videos_r, imagined_rewards = zip(*[self._generate_video(obs_0.copy(), a_0, update_num=self.imagination_horizon//3) for obs_0, a_0 in zip( [rollout.next_states[idx:idx+ self.imagination_horizon] for idx in init_indeces], [rollout.actions[idx:idx+ self.imagination_horizon] for idx in init_indeces]) ]) - videos_r_update = np.concatenate(videos_update, axis=3) + videos_r = np.concatenate(videos_r, axis=3) - videos = np.concatenate([ - rollout.next_states[init_idx:init_idx + self.imagination_horizon].transpose( - 0, 3, 1, 2) for init_idx in init_indeces - ], - axis=3) - videos_comparison = np.expand_dims(np.concatenate([videos, videos_r_update, videos_r], axis=2), 0) + videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r)], axis=2), 0) latent_hist = (self._latent_probs / self._stored_steps).detach().cpu().numpy() latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) @@ -680,9 +668,7 @@ def viz_log(self, rollout, logger, 
epoch_num): logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) rewards_err = torch.Tensor([torch.abs(imagined_rewards[i] - real_rewards[i]) for i in range(len(imagined_rewards))]).mean() - rewards_update_err = np.mean([np.abs(imagined_update_rewards[i] - real_rewards[i]) for i in range(len(imagined_rewards))]) logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) - logger.add_scalar('val/img_update_reward_err', rewards_update_err.item(), epoch_num) logger.add_scalar(f'val/reward', real_rewards[0], epoch_num) diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index 955edb4..801e643 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -11,7 +11,7 @@ world_model_lr: 3e-4 # ActorCritic parameters discount_factor: 0.999 -imagination_horizon: 16 +imagination_horizon: 15 actor_lr: 8e-5 # mixing of reinforce and maximizing value func From 9d3d00bf9fff12a87f18b5536344b965f57160d9 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 28 Jan 2023 16:12:31 +0000 Subject: [PATCH 035/106] Added state abstraction which simplified a lot ! --- pyproject.toml | 2 + rl_sandbox/agents/dreamer_v2.py | 249 +++++++++++++++----------------- rl_sandbox/train.py | 4 + 3 files changed, 123 insertions(+), 132 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c4bdc92..60c8601 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,8 @@ unpackable = '^0.0.4' hydra-core = "^1.2.0" matplotlib = "^3.0.0" webdataset = "^0.2.20" +jaxtyping = '^0.2.0' +lovely_tensors = '^0.1.10' [tool.yapf] based_on_style = "pep8" diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index d08f09f..242f343 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -1,5 +1,6 @@ import typing as t from collections import defaultdict +from dataclasses import dataclass import matplotlib.pyplot as plt import numpy as np @@ -7,6 +8,7 @@ import torch.distributions as td from torch import nn from torch.nn import functional as F +from jaxtyping import Float, Bool from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.fc_nn import fc_nn_generator @@ -95,6 +97,35 @@ def forward(self, x): def Dist(val): return DistLayer('onehot')(val) +@dataclass +class State: + determ: Float[torch.Tensor, 'seq batch determ'] + stoch_logits: Float[torch.Tensor, 'seq batch latent_classes latent_dim'] + stoch_: t.Optional[Bool[torch.Tensor, 'seq batch stoch_dim']] = None + + @property + def combined(self): + return torch.concat([self.determ, self.stoch], dim=-1) + + @property + def stoch(self): + if self.stoch_ is None: + self.stoch_ = Dist(self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:2] + (-1,)) + return self.stoch_ + + @property + def stoch_dist(self): + return Dist(self.stoch_logits) + + @classmethod + def stack(cls, states: list['State'], dim = 0): + if states[0].stoch_ is not None: + stochs = torch.cat([state.stoch for state in states], dim=dim) + else: + stochs = None + return State(torch.cat([state.determ for state in states], dim=dim), + torch.cat([state.stoch_logits for state in states], dim=dim), + stochs) class RSSM(nn.Module): """ @@ -145,7 +176,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' - View((-1, latent_dim, self.latent_classes))) for _ in range(self.ensemble_num) + View((1, -1, latent_dim, 
self.latent_classes))) for _ in range(self.ensemble_num) ]) # For observation we do not have ensemble @@ -157,9 +188,9 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'obs_dist' - View((-1, latent_dim, self.latent_classes))) + View((1, -1, latent_dim, self.latent_classes))) - def estimate_stochastic_latent(self, prev_determ): + def estimate_stochastic_latent(self, prev_determ: torch.Tensor): dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] # NOTE: Maybe something smarter can be used instead of # taking only one random between all ensembles @@ -168,54 +199,33 @@ def estimate_stochastic_latent(self, prev_determ): return dists_per_model[idx] def predict_next(self, - stoch_latent, - action, - deter_state: t.Optional[torch.Tensor] = None): - # FIXME: Move outside of rssm to omit checking - if deter_state is None: - deter_state = torch.zeros(*stoch_latent.shape[:2], self.hidden_size).to( - next(self.stoch_net.parameters()).device) - x = self.pre_determ_recurrent(torch.concat([stoch_latent, action], dim=-1)) + prev_state: State, + action) -> State: + x = self.pre_determ_recurrent(torch.concat([prev_state.stoch, action], dim=-1)) # NOTE: x and determ are actually the same value if sequence of 1 is inserted - x, determ = self.determ_recurrent(x, deter_state) + x, determ = self.determ_recurrent(x, prev_state.determ) # used for KL divergence predicted_stoch_logits = self.estimate_stochastic_latent(x) - return determ, predicted_stoch_logits + return State(determ, predicted_stoch_logits) - def update_current(self, determ, embed): # Dreamer 'obs_out' - return self.stoch_net(torch.concat([determ, embed], dim=-1)) + def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' + return State(prior.determ, self.stoch_net(torch.concat([prior.determ, embed], dim=-1))) - def forward(self, h_prev: t.Optional[tuple[torch.Tensor, torch.Tensor]], embed, - action): + def forward(self, h_prev: State, embed, + action) -> tuple[State, State]: """ 'h' <- internal state of the world 'z' <- latent embedding of current observation 'a' <- action taken on prev step Returns 'h_next' <- the next next of the world """ + prior = self.predict_next(h_prev, action) + posterior = self.update_current(prior, embed) + + return prior, posterior + - # FIXME: Use zero vector for prev_state of first - # Move outside of rssm to omit checking - if h_prev is None: - h_prev = (torch.zeros(( - *embed.shape[:-1], - self.hidden_size, - ), - device=next(self.stoch_net.parameters()).device), - torch.zeros( - (*action.shape[:-1], self.latent_dim * self.latent_classes), - device=next(self.stoch_net.parameters()).device)) - deter_prev, stoch_prev = h_prev - determ, prior_stoch_logits = self.predict_next(stoch_prev, - action, - deter_state=deter_prev) - posterior_stoch_logits = self.update_current(determ, embed) - - return [determ, prior_stoch_logits, posterior_stoch_logits] - - -# NOTE: residual blocks are not used inside dreamer class Encoder(nn.Module): def __init__(self, kernel_sizes=[4, 4, 4, 4]): @@ -319,22 +329,26 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) - def predict_next(self, latent_repr, action, world_state: t.Optional[torch.Tensor]): - determ_state, next_repr_logits = self.recurrent_model.predict_next( - latent_repr.unsqueeze(0), 
action.unsqueeze(0), world_state) + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): + device = next(self.parameters()).device + return State(torch.zeros(seq_size, batch_size, self.rssm_dim, device=device), + torch.zeros(seq_size, batch_size, self.latent_classes, self.latent_dim, device=device), + torch.zeros(seq_size, batch_size, self.latent_classes * self.latent_dim, device=device)) - next_repr = Dist(next_repr_logits).rsample().reshape( - -1, self.latent_dim * self.latent_classes) - inp = torch.concat([determ_state.squeeze(0), next_repr], dim=-1) - reward = self.reward_predictor(inp).mode - discount_factors = self.discount_predictor(inp).sample() - return determ_state, next_repr, reward, discount_factors + def predict_next(self, prev_state: State, action): + prior = self.recurrent_model.predict_next(prev_state, action) - def get_latent(self, obs: torch.Tensor, action, state): - embed = self.encoder(obs) - determ, _, latent_repr_logits = self.recurrent_model.forward(state, embed.unsqueeze(0), + reward = self.reward_predictor(prior.combined).mode + discount_factors = self.discount_predictor(prior.combined).sample() + return prior, reward, discount_factors + + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State]) -> State: + if state is None: + state = self.get_initial_state() + embed = self.encoder(obs.unsqueeze(0)) + _, posterior = self.recurrent_model.forward(state, embed.unsqueeze(0), action) - return determ, latent_repr_logits + return posterior def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, discount: torch.Tensor, first: torch.Tensor): @@ -348,7 +362,6 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, d_c = discount.reshape(-1, self.cluster_size, 1) first_c = first.reshape(-1, self.cluster_size, 1) - h_prev = None losses = defaultdict(lambda: torch.zeros(1).to(next(self.parameters()).device)) metrics = defaultdict(lambda: torch.zeros(1).to(next(self.parameters()).device)) @@ -364,38 +377,30 @@ def KL(dist1, dist2, free_nat = True): kl_rhs = KL_(Dist(dist2), Dist(dist1.detach())).mean() return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) - latent_vars = [] - determ_vars = [] - prior_logits = [] - posterior_logits = [] - - # inps = [] - # reconstructed = [] + priors = [] + posteriors = [] + prev_state = self.get_initial_state(b // self.cluster_size) for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 embed_t, a_t, first_t = embed_c[:, t].unsqueeze(0), a_c[:, t].unsqueeze(0), first_c[:, t].unsqueeze(0) a_t = a_t * (1 - first_t) - determ_t, prior_stoch_logits, posterior_stoch_logits = self.recurrent_model.forward( - h_prev, embed_t, a_t) - posterior_stoch = Dist(posterior_stoch_logits).rsample().reshape( - -1, self.latent_dim * self.latent_classes) + prior, posterior = self.recurrent_model.forward(prev_state, embed_t, a_t) + prev_state = posterior + + priors.append(prior) + posteriors.append(posterior) - h_prev = [determ_t, posterior_stoch.unsqueeze(0)] - determ_vars.append(determ_t.squeeze(0)) - latent_vars.append(posterior_stoch) + posterior = State.stack(posteriors) + prior = State.stack(priors) - prior_logits.append(prior_stoch_logits) - posterior_logits.append(posterior_stoch_logits) + r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) + f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) + x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) - # inp = torch.concat([determ_vars.squeeze(0), 
posterior_stoch], dim=1) - inp = torch.concat([torch.stack(determ_vars, dim=1), torch.stack(latent_vars, dim=1)], dim=-1) - r_pred = self.reward_predictor(inp) - f_pred = self.discount_predictor(inp) - x_r = self.image_predictor(torch.flatten(inp, 0, 1)) - prior_logits = torch.flatten(torch.stack(prior_logits, dim=1), 0, 1) - posterior_logits = torch.flatten(torch.stack(posterior_logits, dim=1), 0, 1) + prior_logits = prior.stoch_logits + posterior_logits = posterior.stoch_logits losses['loss_reconstruction'] = -x_r.log_prob(obs).mean() losses['loss_reward_pred'] = -r_pred.log_prob(r_c).mean() @@ -406,7 +411,7 @@ def KL(dist1, dist2, free_nat = True): metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() - return losses, inp.flatten(0, 1).detach(), metrics + return losses, posterior, metrics class ImaginativeCritic(nn.Module): @@ -451,7 +456,7 @@ def _lambda_return(self, vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor): # Formula is actually slightly different than in paper # https://github.com/danijar/dreamerv2/issues/25 v_lambdas = [vs[-1]] - for i in range(rs.shape[0] - 2, -1, -1): + for i in range(rs.shape[0] - 1, -1, -1): v_lambda = rs[i] + ds[i] * ( (1 - self.lambda_) * vs[i+1] + self.lambda_ * v_lambdas[-1]) @@ -531,42 +536,31 @@ def __init__( self.reset() def imagine_trajectory( - self, z_0, precomp_actions: t.Optional[list[Action]] = None, horizon: t.Optional[int] = None - ) -> tuple[torch.Tensor, torch.distributions.Distribution, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor, torch.Tensor]: + self, init_state: State, precomp_actions: t.Optional[list[Action]] = None, horizon: t.Optional[int] = None + ) -> tuple[State, torch.Tensor, torch.Tensor, + torch.Tensor]: if horizon is None: horizon = self.imagination_horizon - world_state = torch.zeros(1, z_0.shape[0], 200, device=z_0.device) - zs, actions_dists, actions, next_zs, rewards, ts, determs = [], [], [], [], [], [], [] - z = z_0 + states, actions, rewards, ts = [], [], [], [] + prev_state = init_state for i in range(horizon): - # FIXME: if somebody sees it, you have no credibility as programmer - if z.shape[1] == 1224: - z, world_state = z[:, :1024], z[:, 1024:].unsqueeze(0) if precomp_actions is not None: - a_dist = None a = precomp_actions[i].unsqueeze(0) else: - a_dist = self.actor(torch.cat([world_state.squeeze(), z], dim=-1).detach()) + a_dist = self.actor(prev_state.combined.detach()) a = a_dist.rsample() - world_state, next_z, reward, discount = self.world_model.predict_next( - z, a, world_state) + prior, reward, discount = self.world_model.predict_next(prev_state, a) + prev_state = prior - zs.append(z) - actions_dists.append(a_dist) - next_zs.append(next_z) + states.append(prior) rewards.append(reward) ts.append(discount) - determs.append(world_state[0]) actions.append(a) - z = next_z - return (torch.stack(zs), actions_dists, torch.stack(next_zs), - torch.stack(rewards), torch.stack(ts), torch.stack(determs), torch.stack(actions)) + return (State.stack(states), torch.cat(actions), torch.cat(rewards), torch.cat(ts)) def reset(self): - self._state = None - # FIXME: instead of zero, it should be mode of distribution + self._state = self.world_model.get_initial_state() self._last_action = torch.zeros((1, 1, self.actions_num), device=self.device) self._latent_probs = torch.zeros((32, 32), device=self.device) self._action_probs = torch.zeros((self.actions_num), device=self.device) @@ -584,19 +578,16 @@ def preprocess_obs(obs: 
torch.Tensor): def get_action(self, obs: Observation) -> Action: # NOTE: pytorch fails without .copy() only when get_action is called obs = torch.from_numpy(obs.copy()).to(self.device) - obs = self.preprocess_obs(obs).unsqueeze(0) + obs = self.preprocess_obs(obs) - determ, latent_repr_logits = self.world_model.get_latent(obs, self._last_action, - self._state) - latent_repr_dist = Dist(latent_repr_logits) - self._state = (determ, latent_repr_dist.rsample().reshape(-1, - 32 * 32).unsqueeze(0)) + self._state = self.world_model.get_latent(obs, self._last_action, self._state) - actor_dist = self.actor(torch.cat(self._state, dim=-1)) + actor_dist = self.actor(self._state.combined) + self._last_action = actor_dist.sample() if False: self._action_probs += actor_dist.base_dist.probs.squeeze() - self._latent_probs += latent_repr_dist.base_dist.probs.squeeze() + self._latent_probs += self._state.stoch_dist.base_dist.probs.squeeze() self._stored_steps += 1 if False: @@ -608,26 +599,23 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ obs = torch.from_numpy(obs.copy()).to(self.device) obs = self.preprocess_obs(obs) actions = self.from_np(actions) - state = None video = [] rews = [] - z_0 = Dist(self.world_model.get_latent(obs[0].unsqueeze(0), actions[0].unsqueeze(0).unsqueeze(0), None)[1]).rsample().reshape(-1, 32 * 32).unsqueeze(0) + state = self.world_model.get_latent(obs[0], actions[0].unsqueeze(0).unsqueeze(0), None) for idx, (o, a) in enumerate(list(zip(obs, actions))): if idx >= update_num: break - determ, stoch_logits = self.world_model.get_latent(o.unsqueeze(0), a.unsqueeze(0).unsqueeze(0), state) - z_0 = Dist(stoch_logits).rsample().reshape(-1, 32 * 32).unsqueeze(0) - state = (determ, z_0) - inp = torch.concat([determ, z_0], dim=-1) - video_r = self.world_model.image_predictor(inp).mode.cpu().detach().numpy() - rews.append(self.world_model.reward_predictor(inp).mode.item()) + state = self.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), state) + video_r = self.world_model.image_predictor(state.combined).mode.cpu().detach().numpy() + rews.append(self.world_model.reward_predictor(state.combined).mode.item()) video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) video.append(video_r) if update_num < len(obs): - zs, _, _, rews, _, determs, _ = self.imagine_trajectory(z_0.squeeze(0), actions[update_num+1:], horizon=self.imagination_horizon - 1 - update_num) - video_r = self.world_model.image_predictor(torch.concat([torch.concat([torch.zeros_like(determs[0]).unsqueeze(0), determs]), torch.concat([z_0, zs])], dim=-1)).mode.cpu().detach().numpy() + states, _, rews, _ = self.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.imagination_horizon - 1 - update_num) + states = State.stack([state, states]) + video_r = self.world_model.image_predictor(states.combined).mode.cpu().detach().numpy() video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) video.append(video_r) @@ -649,7 +637,7 @@ def viz_log(self, rollout, logger, epoch_num): ]) videos_r = np.concatenate(videos_r, axis=3) - videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r)], axis=2), 0) + videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r + 1)], axis=2), 0) latent_hist = (self._latent_probs / self._stored_steps).detach().cpu().numpy() latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) @@ -694,11 +682,11 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: 
Observation # take some latent embeddings as initial with torch.cuda.amp.autocast(enabled=True): - losses, discovered_latents, wm_metrics = self.world_model.calculate_loss( + losses, discovered_states, wm_metrics = self.world_model.calculate_loss( obs, a, r, discount_factors, first_flags) # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device - # world_model_loss = torch.Tensor(1).to(self.device) + world_model_loss = torch.Tensor(0).to(self.device) world_model_loss = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + losses['loss_kl_reg'] + @@ -719,17 +707,14 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation metrics |= wm_metrics with torch.cuda.amp.autocast(enabled=True): - # idx = torch.randperm(discovered_latents.size(0)) - # initial_states = discovered_latents[idx] - # Dreamer does not shuffle - initial_states = discovered_latents - losses_ac = defaultdict( lambda: torch.zeros(1).to(self.device)) + initial_states = State(discovered_states.determ.flatten(0, 1).unsqueeze(0).detach(), + discovered_states.stoch_logits.flatten(0, 1).unsqueeze(0).detach(), + discovered_states.stoch_.flatten(0, 1).unsqueeze(0).detach()) - zs, _, next_zs, rewards, discount_factors, determs, actions = self.imagine_trajectory( - initial_states) - zs = torch.cat([determs.squeeze(), zs], dim=-1) + states, actions, rewards, discount_factors = self.imagine_trajectory(initial_states) + zs = states.combined rewards = self.world_model.reward_normalizer(rewards) # Discount prediction is disabled for dmc vision in Dreamer @@ -738,14 +723,14 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # Discounted factors should be shifted as they predict whether next state is terminal # First discount factor on contrary is always 1 as it cannot lead to trajectory finish - discount_factors = torch.cat([torch.ones_like(discount_factors[:1]), discount_factors[:-2]], dim=0).detach() + discount_factors = torch.cat([torch.ones_like(discount_factors[:1]), discount_factors[:-1]], dim=0).detach() vs = self.critic.lambda_return(zs, rewards[:-1], discount_factors) # Ignore all factors after first is_finished state discount_factors = torch.cumprod(discount_factors, dim=0) - predicted_vs_dist = self.critic.estimate_value(zs[:-1].detach()) + predicted_vs_dist = self.critic.estimate_value(zs.detach()) losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs.detach())).mean() metrics['critic/avg_target_value'] = self.critic.target_critic(zs[1:]).mode.mean() @@ -754,9 +739,9 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # last action should be ignored as it is not used to predict next state, thus no feedback # first value should be ignored as it is comes from replay buffer - action_dists = self.actor(zs[1:-1].detach()) + action_dists = self.actor(zs[1:].detach()) baseline = self.critic.target_critic(zs[1:-1]).mode - advantage = (vs[1:] - baseline).detach() + advantage = (vs[1:-1] - baseline).detach() losses_ac['loss_actor_reinforce'] += 0# -(self.rho * action_dists.base_dist.log_prob(actions[1:-1].detach()).unsqueeze(2) * discount_factors[:-2] * advantage).mean() losses_ac['loss_actor_dynamics_backprop'] = -((1 - self.rho) * (vs[1:]*discount_factors[:-1])).mean() diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 8490d1b..538fbb8 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -8,6 +8,7 @@ import torch from torch.profiler import profile, record_function, ProfilerActivity 
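# A minimal standalone sketch of the TD(lambda) target computed by
# ImaginativeCritic.lambda_return in the actor/critic update above. The function
# name and the 1-D [H]-shaped tensors are illustrative assumptions, not the repo's
# exact batched layout; lam mirrors critic_value_target_lambda (0.95 in the config).
import torch

def lambda_return_sketch(r: torch.Tensor, v: torch.Tensor, d: torch.Tensor, lam: float = 0.95) -> torch.Tensor:
    # G[t] = r[t] + d[t] * ((1 - lam) * v[t+1] + lam * G[t+1]), bootstrapped with G[H-1] = v[H-1]
    G = v[-1]
    out = [G]
    for t in range(len(r) - 2, -1, -1):
        G = r[t] + d[t] * ((1 - lam) * v[t + 1] + lam * G)
        out.append(G)
    return torch.stack(out[::-1])  # lambda-returns aligned with steps 0..H-1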
+import lovely_tensors as lt from rl_sandbox.agents.random_agent import RandomAgent from rl_sandbox.agents.explorative_agent import ExplorativeAgent @@ -33,11 +34,14 @@ def add_image(*args, **kwargs): @hydra.main(version_base="1.2", config_path='config', config_name='config') def main(cfg: DictConfig): + lt.monkey_patch() # print(OmegaConf.to_yaml(cfg)) torch.distributions.Distribution.set_default_validate_args(False) torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(cfg.seed) random.seed(cfg.seed) torch.manual_seed(cfg.seed) np.random.seed(cfg.seed) From 0dabd542f2d4430deceee387bf6e3b49d59b6507 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 28 Jan 2023 19:35:29 +0000 Subject: [PATCH 036/106] Fixed video log --- rl_sandbox/agents/dreamer_v2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 242f343..d684e4a 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -609,14 +609,14 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ state = self.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), state) video_r = self.world_model.image_predictor(state.combined).mode.cpu().detach().numpy() rews.append(self.world_model.reward_predictor(state.combined).mode.item()) - video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) + video_r = (video_r + 0.5) video.append(video_r) if update_num < len(obs): states, _, rews, _ = self.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.imagination_horizon - 1 - update_num) states = State.stack([state, states]) video_r = self.world_model.image_predictor(states.combined).mode.cpu().detach().numpy() - video_r = ((video_r + 0.5) * 255.0).astype(np.uint8) + video_r = (video_r + 0.5) video.append(video_r) return np.concatenate(video), sum(rews) @@ -627,7 +627,7 @@ def viz_log(self, rollout, logger, epoch_num): videos = np.concatenate([ rollout.next_states[init_idx:init_idx + self.imagination_horizon].transpose( 0, 3, 1, 2) for init_idx in init_indeces - ], axis=3) + ], axis=3).astype(np.float32) / 255.0 real_rewards = [rollout.rewards[idx:idx+ self.imagination_horizon].sum() for idx in init_indeces] @@ -638,6 +638,7 @@ def viz_log(self, rollout, logger, epoch_num): videos_r = np.concatenate(videos_r, axis=3) videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r + 1)], axis=2), 0) + videos_comparison = (videos_comparison * 255.0).astype(np.uint8) latent_hist = (self._latent_probs / self._stored_steps).detach().cpu().numpy() latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) From d8ba8c097458890a01a7f09bf4ac11d612875df0 Mon Sep 17 00:00:00 2001 From: Midren Date: Sun, 29 Jan 2023 09:53:34 +0000 Subject: [PATCH 037/106] Fixes --- rl_sandbox/agents/dreamer_v2.py | 6 ++++-- rl_sandbox/utils/rollout_generation.py | 6 +++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index d684e4a..368a65e 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -541,8 +541,11 @@ def imagine_trajectory( torch.Tensor]: if horizon is None: horizon = self.imagination_horizon - states, actions, rewards, ts = [], [], [], [] + prev_state = init_state + prev_action = torch.zeros_like(self.actor(prev_state.combined.detach()).mode) + states, 
actions, rewards, ts = [init_state], [prev_action], [torch.Tensor(0, device=prev_action.device)], [torch.Tensor(1, device=prev_action.device)] + for i in range(horizon): if precomp_actions is not None: a = precomp_actions[i].unsqueeze(0) @@ -614,7 +617,6 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ if update_num < len(obs): states, _, rews, _ = self.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.imagination_horizon - 1 - update_num) - states = State.stack([state, states]) video_r = self.world_model.image_predictor(states.combined).mode.cpu().detach().numpy() video_r = (video_r + 0.5) video.append(video_r) diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index be9eb3f..ded11d6 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -55,6 +55,8 @@ def iter_rollout( agent.reset() prev_action = np.zeros_like(agent.get_action(state)) + prev_reward = 0 + prev_terminated = False while not terminated: action = agent.get_action(state) @@ -63,9 +65,11 @@ def iter_rollout( # FIXME: will break for non-DM obs = env.render() if collect_obs else None # if collect_obs and isinstance(env, dmEnv): - yield state, prev_action, reward, new_state, terminated, obs + yield state, prev_action, prev_reward, new_state, prev_terminated, obs state = new_state prev_action = action + prev_reward = reward + prev_terminated = terminated def collect_rollout(env: Env, From 1164ba6faa57d0e5b661f45154d91937afff0949 Mon Sep 17 00:00:00 2001 From: Midren Date: Wed, 1 Feb 2023 11:05:15 +0000 Subject: [PATCH 038/106] Fixed differences with original DreamerV2 - Fixed visualization - Added LayerNorm inside GRU - Fixed amount of layers in MLP - Fixed state indeces in actor/critic loss Conflicts: rl_sandbox/agents/dreamer_v2.py rl_sandbox/train.py --- rl_sandbox/agents/dreamer_v2.py | 105 +++++++++++++++++-------- rl_sandbox/config/env/dm_cartpole.yaml | 2 +- rl_sandbox/train.py | 18 +++-- rl_sandbox/utils/env.py | 2 +- rl_sandbox/utils/fc_nn.py | 2 +- 5 files changed, 86 insertions(+), 43 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 368a65e..aac20c4 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -127,6 +127,35 @@ def stack(cls, states: list['State'], dim = 0): torch.cat([state.stoch_logits for state in states], dim=dim), stochs) +class GRUCell(nn.Module): + def __init__(self, input_size, hidden_size, norm=False, update_bias=-1, **kwargs): + super().__init__() + self._size = hidden_size + self._act = torch.tanh + self._norm = norm + self._update_bias = update_bias + self._layer = nn.Linear(input_size + hidden_size, 3 * hidden_size, bias=norm is not None, **kwargs) + if norm: + self._norm = nn.LayerNorm(3 * hidden_size) + + @property + def state_size(self): + return self._size + + def forward(self, x, h): + state = h + parts = self._layer(torch.concat([x, state], -1)) + if self._norm: + dtype = parts.dtype + parts = self._norm(parts.float()) + parts = parts.to(dtype=dtype) + reset, cand, update = parts.chunk(3, dim=-1) + reset = torch.sigmoid(reset) + cand = self._act(reset * cand) + update = torch.sigmoid(update + self._update_bias) + output = update * cand + (1 - update) * state + return output, output + class RSSM(nn.Module): """ Recurrent State Space Model @@ -164,8 +193,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): # nn.LayerNorm(hidden_size), 
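# Reading of the GRUCell added above (descriptive comments only): a single Linear maps
# the concatenated [input, hidden] vector to 3 * hidden_size features, which chunk()
# splits into reset / candidate / update gates; with norm=True one LayerNorm is applied
# to all three chunks jointly before the nonlinearities. The candidate is
# tanh(reset * cand), i.e. the reset gate scales the candidate pre-activation, and
# update_bias=-1 gives update = sigmoid(x - 1), roughly 0.27 at zero pre-activation,
# so the cell initially favors carrying over the previous deterministic state, in line
# with the original DreamerV2 GRU variant this patch aims to match.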
nn.ELU(inplace=True) ) - self.determ_recurrent = nn.GRU(input_size=hidden_size, - hidden_size=hidden_size) # Dreamer gru '_cell' + self.determ_recurrent = GRUCell(input_size=hidden_size, hidden_size=hidden_size, norm=True) # Dreamer gru '_cell' # Calculate stochastic state from prior embed # shared between all ensemble models @@ -318,13 +346,13 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.reward_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, hidden_size=400, - num_layers=4, + num_layers=5, intermediate_activation=nn.ELU, final_activation=DistLayer('mse')) self.discount_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, hidden_size=400, - num_layers=4, + num_layers=5, intermediate_activation=nn.ELU, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) @@ -402,11 +430,13 @@ def KL(dist1, dist2, free_nat = True): prior_logits = prior.stoch_logits posterior_logits = posterior.stoch_logits - losses['loss_reconstruction'] = -x_r.log_prob(obs).mean() - losses['loss_reward_pred'] = -r_pred.log_prob(r_c).mean() - losses['loss_discount_pred'] = -f_pred.log_prob(d_c).mean() + losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() + losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() + losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) + metrics['reward_mean'] = r.mean() + metrics['reward_std'] = r.std() metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() @@ -428,13 +458,13 @@ def __init__(self, discount_factor: float, update_interval: int, self.critic = fc_nn_generator(latent_dim, 1, 400, - 4, + 5, intermediate_activation=nn.ELU, final_activation=DistLayer('mse')) self.target_critic = fc_nn_generator(latent_dim, 1, 400, - 4, + 5, intermediate_activation=nn.ELU, final_activation=DistLayer('mse')) self.target_critic.requires_grad_(False) @@ -462,7 +492,8 @@ def _lambda_return(self, vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor): self.lambda_ * v_lambdas[-1]) v_lambdas.append(v_lambda) - return torch.stack(list(reversed(v_lambdas))) + # FIXME: it copies array, so it is quite slow + return torch.stack(v_lambdas).flip(dims=(0,))[:-1] def lambda_return(self, zs, rs, ds): vs = self.target_critic(zs).mode @@ -493,8 +524,10 @@ def __init__( world_model_lr: float, actor_lr: float, critic_lr: float, - device_type: str = 'cpu'): + device_type: str = 'cpu', + logger = None): + self.logger = logger self.device = device_type self.imagination_horizon = imagination_horizon self.cluster_size = batch_cluster_size @@ -511,7 +544,7 @@ def __init__( self.actor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, actions_num * 2 if action_type == 'continuous' else actions_num, 400, - 4, + 5, intermediate_activation=nn.ELU, final_activation=DistLayer('normal_trunc' if action_type == 'continuous' else 'onehot')).to(device_type) @@ -543,8 +576,11 @@ def imagine_trajectory( horizon = self.imagination_horizon prev_state = init_state - prev_action = torch.zeros_like(self.actor(prev_state.combined.detach()).mode) - states, actions, rewards, ts = [init_state], [prev_action], [torch.Tensor(0, device=prev_action.device)], [torch.Tensor(1, device=prev_action.device)] + prev_action = 
torch.zeros_like(self.actor(prev_state.combined.detach()).mean) + states, actions, rewards, ts = ([init_state], + [prev_action], + [self.world_model.reward_predictor(init_state.combined).mode], + [torch.ones(prev_action.shape[:-1] + (1,), device=prev_action.device)]) for i in range(horizon): if precomp_actions is not None: @@ -605,9 +641,9 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ video = [] rews = [] - state = self.world_model.get_latent(obs[0], actions[0].unsqueeze(0).unsqueeze(0), None) + state = None for idx, (o, a) in enumerate(list(zip(obs, actions))): - if idx >= update_num: + if idx > update_num: break state = self.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), state) video_r = self.world_model.image_predictor(state.combined).mode.cpu().detach().numpy() @@ -615,23 +651,26 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ video_r = (video_r + 0.5) video.append(video_r) + rews = torch.Tensor(rews).to(obs.device) + if update_num < len(obs): - states, _, rews, _ = self.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.imagination_horizon - 1 - update_num) - video_r = self.world_model.image_predictor(states.combined).mode.cpu().detach().numpy() + states, _, rews_2, _ = self.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.imagination_horizon - 1 - update_num) + rews = torch.cat([rews, rews_2[1:].squeeze()]) + video_r = self.world_model.image_predictor(states.combined[1:]).mode.cpu().detach().numpy() video_r = (video_r + 0.5) video.append(video_r) - return np.concatenate(video), sum(rews) + return np.concatenate(video), rews def viz_log(self, rollout, logger, epoch_num): - init_indeces = np.random.choice(len(rollout.states) - self.imagination_horizon, 3) + init_indeces = np.random.choice(len(rollout.states) - self.imagination_horizon, 5) videos = np.concatenate([ rollout.next_states[init_idx:init_idx + self.imagination_horizon].transpose( 0, 3, 1, 2) for init_idx in init_indeces ], axis=3).astype(np.float32) / 255.0 - real_rewards = [rollout.rewards[idx:idx+ self.imagination_horizon].sum() for idx in init_indeces] + real_rewards = [rollout.rewards[idx:idx+ self.imagination_horizon] for idx in init_indeces] videos_r, imagined_rewards = zip(*[self._generate_video(obs_0.copy(), a_0, update_num=self.imagination_horizon//3) for obs_0, a_0 in zip( [rollout.next_states[idx:idx+ self.imagination_horizon] for idx in init_indeces], @@ -639,7 +678,7 @@ def viz_log(self, rollout, logger, epoch_num): ]) videos_r = np.concatenate(videos_r, axis=3) - videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r + 1)], axis=2), 0) + videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r + 1)/2], axis=2), 0) videos_comparison = (videos_comparison * 255.0).astype(np.uint8) latent_hist = (self._latent_probs / self._stored_steps).detach().cpu().numpy() latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) @@ -656,12 +695,12 @@ def viz_log(self, rollout, logger, epoch_num): pass logger.add_image('val/latent_probs', latent_hist, epoch_num, dataformats='HW') logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=1), epoch_num, dataformats='HW') - logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) + logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num, fps=20) - rewards_err = torch.Tensor([torch.abs(imagined_rewards[i] - 
real_rewards[i]) for i in range(len(imagined_rewards))]).mean() + rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) - logger.add_scalar(f'val/reward', real_rewards[0], epoch_num) + logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) def from_np(self, arr: np.ndarray): arr = torch.from_numpy(arr) if isinstance(arr, np.ndarray) else arr @@ -724,7 +763,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # as trajectory will not abruptly stop discount_factors = self.critic.gamma * torch.ones_like(rewards) - # Discounted factors should be shifted as they predict whether next state is terminal + # Discounted factors should be shifted as they predict whether next state cannot be used # First discount factor on contrary is always 1 as it cannot lead to trajectory finish discount_factors = torch.cat([torch.ones_like(discount_factors[:1]), discount_factors[:-1]], dim=0).detach() @@ -733,8 +772,8 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # Ignore all factors after first is_finished state discount_factors = torch.cumprod(discount_factors, dim=0) - predicted_vs_dist = self.critic.estimate_value(zs.detach()) - losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs.detach())).mean() + predicted_vs_dist = self.critic.estimate_value(zs[:-1].detach()) + losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs.detach()).unsqueeze(2)*discount_factors[:-1]).mean() metrics['critic/avg_target_value'] = self.critic.target_critic(zs[1:]).mode.mean() metrics['critic/avg_lambda_value'] = vs.mean() @@ -742,17 +781,17 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # last action should be ignored as it is not used to predict next state, thus no feedback # first value should be ignored as it is comes from replay buffer - action_dists = self.actor(zs[1:].detach()) - baseline = self.critic.target_critic(zs[1:-1]).mode - advantage = (vs[1:-1] - baseline).detach() + action_dists = self.actor(zs[:-2].detach()) + baseline = self.critic.target_critic(zs[:-2]).mode + advantage = (vs[1:] - baseline).detach() losses_ac['loss_actor_reinforce'] += 0# -(self.rho * action_dists.base_dist.log_prob(actions[1:-1].detach()).unsqueeze(2) * discount_factors[:-2] * advantage).mean() - losses_ac['loss_actor_dynamics_backprop'] = -((1 - self.rho) * (vs[1:]*discount_factors[:-1])).mean() + losses_ac['loss_actor_dynamics_backprop'] = -((1 - self.rho) * (vs[1:]*discount_factors[:-2])).mean() def calculate_entropy(dist): return dist.entropy().unsqueeze(2) # return dist.base_dist.base_dist.entropy().unsqueeze(2) - losses_ac['loss_actor_entropy'] += -(self.eta * calculate_entropy(action_dists)*discount_factors[:-1]).mean() + losses_ac['loss_actor_entropy'] += -(self.eta * calculate_entropy(action_dists)*discount_factors[:-2]).mean() losses_ac['loss_actor'] = losses_ac['loss_actor_reinforce'] + losses_ac['loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] # mean and std are estimated statistically as tanh transformation is used diff --git a/rl_sandbox/config/env/dm_cartpole.yaml b/rl_sandbox/config/env/dm_cartpole.yaml index bf2bae9..4bd3888 100644 --- a/rl_sandbox/config/env/dm_cartpole.yaml +++ b/rl_sandbox/config/env/dm_cartpole.yaml @@ -3,7 +3,7 @@ domain_name: cartpole task_name: swingup run_on_pixels: true obs_res: [64, 64] -camera_id: -1 
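# camera_id note: in dm_control/MuJoCo rendering, -1 selects the free (default) camera
# while 0 selects the first camera defined in the task's model, giving a fixed view for
# the 64x64 pixel observations; the exact camera layout is task-specific.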
+camera_id: 0 repeat_action_num: 2 transforms: - _target_: rl_sandbox.utils.env.ActionNormalizer diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 538fbb8..37d2370 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -31,7 +31,6 @@ def add_video(*args, **kwargs): def add_image(*args, **kwargs): pass - @hydra.main(version_base="1.2", config_path='config', config_name='config') def main(cfg: DictConfig): lt.monkey_patch() @@ -55,6 +54,8 @@ def main(cfg: DictConfig): # TODO: Implement smarter techniques for exploration # (Plan2Explore, etc) + writer = SummaryWriter() + agent = hydra.utils.instantiate(cfg.agent, obs_space_num=env.observation_space.shape[0], # FIXME: feels bad @@ -62,9 +63,8 @@ def main(cfg: DictConfig): # FIXME: currently only continuous tasks actions_num=env.action_space.shape[0], action_type='continuous', - device_type=cfg.device_type) - - writer = SummaryWriter() + device_type=cfg.device_type, + logger=writer) prof = profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], on_trace_ready=torch.profiler.tensorboard_trace_handler('runs/profile_dreamer'), @@ -99,6 +99,7 @@ def main(cfg: DictConfig): agent.viz_log(rollout, writer, -st + i/log_every_n) global_step = 0 + prev_global_step = 0 pbar = tqdm(total=cfg.training.steps, desc='Training') while global_step < cfg.training.steps: ### Training and exploration @@ -123,9 +124,9 @@ def main(cfg: DictConfig): global_step += cfg.env.repeat_action_num pbar.update(cfg.env.repeat_action_num) - # FIXME: Currently works only val_logs_every is multiplier of amount of steps per rollout + # FIXME: find more appealing solution ### Validation - if global_step % cfg.training.val_logs_every == 0: + if (global_step % cfg.training.val_logs_every) < (prev_global_step % cfg.training.val_logs_every): with torch.no_grad(): rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) # TODO: make logs visualization in separate process @@ -144,8 +145,11 @@ def main(cfg: DictConfig): agent.viz_log(rollout, writer, global_step) ### Checkpoint - if global_step % cfg.training.save_checkpoint_every == 0: + if (global_step % cfg.training.save_checkpoint_every) < (prev_global_step % cfg.training.save_checkpoint_every): agent.save_ckpt(global_step, losses) + + prev_global_step = global_step + if cfg.debug.profiler: prof.stop() diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py index 58e08bb..0ac48ce 100644 --- a/rl_sandbox/utils/env.py +++ b/rl_sandbox/utils/env.py @@ -209,9 +209,9 @@ def _step(self, action: Action, repeat_num: int) -> EnvStepResult: rew = 0 for _ in range(repeat_num - 1): ts = self.env.step(action) - rew += ts.reward or 0.0 if ts.last(): break + rew += ts.reward or 0.0 if repeat_num == 1 or not ts.last(): env_res = self._uncode_ts(self.env.step(action)) else: diff --git a/rl_sandbox/utils/fc_nn.py b/rl_sandbox/utils/fc_nn.py index 8473623..6f0fc6c 100644 --- a/rl_sandbox/utils/fc_nn.py +++ b/rl_sandbox/utils/fc_nn.py @@ -10,7 +10,7 @@ def fc_nn_generator(input_num: int, assert num_layers >= 3 layers = [] layers.append(nn.Linear(input_num, hidden_size)) - layers.append(nn.ReLU(inplace=True)) + layers.append(intermediate_activation(inplace=True)) for _ in range(num_layers - 2): layers.append(nn.Linear(hidden_size, hidden_size)) layers.append(intermediate_activation(inplace=True)) From a6a4e448f0d86952703f5184e24f9e96fa286443 Mon Sep 17 00:00:00 2001 From: Midren Date: Fri, 3 Feb 2023 08:33:54 +0000 Subject: [PATCH 039/106] Fixed critical bug with nullified action in training --- 
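The "nullified action" below came from an unconditional integer cast: .to(torch.int64) truncates toward zero, so the continuous actions used here (roughly in [-1, 1]) collapse to 0 before reaching the world model. Illustrative values, not from a real run:

    torch.tensor([[0.37, -0.52]]).to(torch.int64)  # tensor([[0, 0]])

The fix keeps the int64 cast and one-hot encoding only on the discrete-action branch.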
rl_sandbox/agents/dreamer_v2.py | 4 ++-- rl_sandbox/config/config.yaml | 2 +- rl_sandbox/train.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index aac20c4..a059e96 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -710,9 +710,9 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation is_finished: TerminationFlags, is_first: IsFirstFlags): obs = self.preprocess_obs(self.from_np(obs)) - a = self.from_np(a).to(torch.int64) + a = self.from_np(a) if False: - a = F.one_hot(a, num_classes=self.actions_num).squeeze() + a = F.one_hot(a.to(torch.int64), num_classes=self.actions_num).squeeze() r = self.from_np(r) next_obs = self.preprocess_obs(self.from_np(next_obs)) discount_factors = (1 - self.from_np(is_finished).type(torch.float32)) diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index b9b11c8..8b1776e 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -13,7 +13,7 @@ training: steps: 1e6 prefill: 1000 pretrain: 100 - batch_size: 2500 + batch_size: 800 gradient_steps_per_step: 5 save_checkpoint_every: 1e5 val_logs_every: 2.5e3 diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 37d2370..d8dffb5 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -94,7 +94,7 @@ def main(cfg: DictConfig): for rollout in rollouts: video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) - writer.add_video('val/visualization', video, -st + i/log_every_n) + writer.add_video('val/visualization', video, -st + i/log_every_n, fps=20) # FIXME: Very bad from architecture point agent.viz_log(rollout, writer, -st + i/log_every_n) From b689a22d5a83235b575dda58d54ba5a243ae0e91 Mon Sep 17 00:00:00 2001 From: Midren Date: Fri, 17 Feb 2023 19:34:49 +0000 Subject: [PATCH 040/106] added reloading from checkpoint added reloading from checkpoint --- .gitignore | 3 + pyproject.toml | 2 + rl_sandbox/agents/dreamer_v2.py | 15 ++++- rl_sandbox/agents/explorative_agent.py | 4 ++ rl_sandbox/agents/random_agent.py | 4 ++ rl_sandbox/agents/rl_agent.py | 5 ++ rl_sandbox/config/config.yaml | 2 + rl_sandbox/train.py | 83 +++++++++++--------------- 8 files changed, 69 insertions(+), 49 deletions(-) diff --git a/.gitignore b/.gitignore index c18dd8d..7a5b954 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ __pycache__/ +.vscode/ +runs/ +poetry.lock diff --git a/pyproject.toml b/pyproject.toml index 60c8601..094fee9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ version = "0.1.0" description = 'Sandbox for my RL experiments' authors = ['Roman Milishchuk '] packages = [{include = 'rl_sandbox'}] +# add config directory as package data [tool.poetry.dependencies] python = "^3.10" @@ -26,6 +27,7 @@ matplotlib = "^3.0.0" webdataset = "^0.2.20" jaxtyping = '^0.2.0' lovely_tensors = '^0.1.10' +torchshow = '^0.5.0' [tool.yapf] based_on_style = "pep8" diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index a059e96..a7a2013 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -1,6 +1,7 @@ import typing as t from collections import defaultdict from dataclasses import dataclass +from pathlib import Path import matplotlib.pyplot as plt import numpy as np @@ -832,11 +833,21 @@ def save_ckpt(self, epoch_num: int, losses: dict[str, float]): { 'epoch': epoch_num, 'world_model_state_dict': self.world_model.state_dict(), - 
'world_model_optimizer_state_dict': - self.world_model_optimizer.state_dict(), + 'world_model_optimizer_state_dict':self.world_model_optimizer.state_dict(), 'actor_state_dict': self.actor.state_dict(), 'critic_state_dict': self.critic.state_dict(), 'actor_optimizer_state_dict': self.actor_optimizer.state_dict(), 'critic_optimizer_state_dict': self.critic_optimizer.state_dict(), 'losses': losses }, f'dreamerV2-{epoch_num}-{sum(losses.values())}.ckpt') + + def load_ckpt(self, ckpt_path: Path): + ckpt = torch.load(ckpt_path) + self.world_model.load_state_dict(ckpt['world_model_state_dict']) + self.world_model_optimizer.load_state_dict( + ckpt['world_model_optimizer_state_dict']) + self.actor.load_state_dict(ckpt['actor_state_dict']) + self.critic.load_state_dict(ckpt['critic_state_dict']) + self.actor_optimizer.load_state_dict(ckpt['actor_optimizer_state_dict']) + self.critic_optimizer.load_state_dict(ckpt['critic_optimizer_state_dict']) + return ckpt['epoch'] diff --git a/rl_sandbox/agents/explorative_agent.py b/rl_sandbox/agents/explorative_agent.py index c26ebca..444ca70 100644 --- a/rl_sandbox/agents/explorative_agent.py +++ b/rl_sandbox/agents/explorative_agent.py @@ -1,5 +1,6 @@ import numpy as np from nptyping import Float, NDArray, Shape +from pathlib import Path from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.schedulers import Scheduler @@ -26,3 +27,6 @@ def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: Te def save_ckpt(self, epoch_num: int, losses: dict[str, float]): self.policy_ag.save_ckpt(epoch_num, losses) self.expl_ag.save_ckpt(epoch_num, losses) + + def load_ckpt(self, ckpt_path: Path): + pass diff --git a/rl_sandbox/agents/random_agent.py b/rl_sandbox/agents/random_agent.py index 1638a93..ea5da16 100644 --- a/rl_sandbox/agents/random_agent.py +++ b/rl_sandbox/agents/random_agent.py @@ -1,5 +1,6 @@ import numpy as np from nptyping import Float, NDArray, Shape +from pathlib import Path from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.env import Env @@ -19,3 +20,6 @@ def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: Te def save_ckpt(self, epoch_num: int, losses: dict[str, float]): pass + + def load_ckpt(self, ckpt_path: Path): + pass diff --git a/rl_sandbox/agents/rl_agent.py b/rl_sandbox/agents/rl_agent.py index 0080fb5..357ec82 100644 --- a/rl_sandbox/agents/rl_agent.py +++ b/rl_sandbox/agents/rl_agent.py @@ -1,5 +1,6 @@ from typing import Any from abc import ABCMeta, abstractmethod +from pathlib import Path from rl_sandbox.utils.replay_buffer import Action, State, States, Actions, Rewards, TerminationFlags @@ -23,3 +24,7 @@ def reset(self): @abstractmethod def save_ckpt(self, epoch_num: int, losses: dict[str, float]): pass + + @abstractmethod + def load_ckpt(self, ckpt_path: Path): + pass diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 8b1776e..5111091 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -8,8 +8,10 @@ defaults: seed: 42 device_type: cuda +log_message: training: + checkpoint_path: null steps: 1e6 prefill: 1000 pretrain: 100 diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index d8dffb5..e2e62ad 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -1,24 +1,19 @@ import hydra import numpy as np -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig from torch.utils.tensorboard.writer import SummaryWriter from tqdm import tqdm -from pathlib import Path import random 
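# Resuming from a checkpoint (sketch, based on the config keys added in this patch):
# set training.checkpoint_path to a saved .ckpt and load_ckpt above restores the
# module and optimizer state dicts, returning the stored step so train.py continues
# from it. A hypothetical invocation with hydra overrides:
#   python rl_sandbox/train.py training.checkpoint_path=dreamerV2-100000-1.23.ckpt log_message=resume
# The filename is illustrative; save_ckpt writes f'dreamerV2-{epoch_num}-{sum(losses.values())}.ckpt'.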
import torch -from torch.profiler import profile, record_function, ProfilerActivity +from torch.profiler import profile, ProfilerActivity import lovely_tensors as lt -from rl_sandbox.agents.random_agent import RandomAgent -from rl_sandbox.agents.explorative_agent import ExplorativeAgent from rl_sandbox.metrics import MetricsEvaluator from rl_sandbox.utils.env import Env from rl_sandbox.utils.replay_buffer import ReplayBuffer -from rl_sandbox.utils.persistent_replay_buffer import PersistentReplayBuffer -from rl_sandbox.utils.rollout_generation import (collect_rollout, collect_rollout_num, iter_rollout, +from rl_sandbox.utils.rollout_generation import (collect_rollout_num, iter_rollout, fillup_replay_buffer) -from rl_sandbox.utils.schedulers import LinearScheduler class SummaryWriterMock(): @@ -31,6 +26,27 @@ def add_video(*args, **kwargs): def add_image(*args, **kwargs): pass + +def val_logs(agent, val_cfg: DictConfig, env, global_step, writer): + with torch.no_grad(): + rollouts = collect_rollout_num(env, val_cfg.rollout_num, agent) + # TODO: make logs visualization in separate process + # Possibly make the data loader + metrics = MetricsEvaluator().calculate_metrics(rollouts) + for metric_name, metric in metrics.items(): + writer.add_scalar(f'val/{metric_name}', metric, global_step) + + if val_cfg.visualize: + rollouts = collect_rollout_num(env, 1, agent, collect_obs=True) + + for rollout in rollouts: + video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) + writer.add_video('val/visualization', video, global_step, fps=20) + # FIXME: Very bad from architecture point + with torch.no_grad(): + agent.viz_log(rollout, writer, global_step) + + @hydra.main(version_base="1.2", config_path='config', config_name='config') def main(cfg: DictConfig): lt.monkey_patch() @@ -50,11 +66,9 @@ def main(cfg: DictConfig): buff = ReplayBuffer() fillup_replay_buffer(env, buff, max(cfg.training.prefill, cfg.training.batch_size)) - metrics_evaluator = MetricsEvaluator() - # TODO: Implement smarter techniques for exploration # (Plan2Explore, etc) - writer = SummaryWriter() + writer = SummaryWriter(comment=cfg.log_message or "") agent = hydra.utils.instantiate(cfg.agent, obs_space_num=env.observation_space.shape[0], @@ -72,34 +86,25 @@ def main(cfg: DictConfig): with_stack=True) if cfg.debug.profiler else None for i in tqdm(range(int(cfg.training.pretrain)), desc='Pretraining'): + if cfg.training.checkpoint_path is not None: + break s, a, r, n, f, first = buff.sample(cfg.training.batch_size, cluster_size=cfg.agent.get('batch_cluster_size', 1)) losses = agent.train(s, a, r, n, f, first) for loss_name, loss in losses.items(): writer.add_scalar(f'pre_train/{loss_name}', loss, i) + # TODO: remove constants log_every_n = 25 st = int(cfg.training.pretrain) // log_every_n - # FIXME: extract logging to seperate entity to omit - # copy-paste if i % log_every_n == 0: - rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) - # TODO: make logs visualization in separate process - metrics = metrics_evaluator.calculate_metrics(rollouts) - for metric_name, metric in metrics.items(): - writer.add_scalar(f'val/{metric_name}', metric, -st + i/log_every_n) - - if cfg.validation.visualize: - rollouts = collect_rollout_num(env, 1, agent, collect_obs=True) - - for rollout in rollouts: - video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) - writer.add_video('val/visualization', video, -st + i/log_every_n, fps=20) - # FIXME: Very bad from architecture point - agent.viz_log(rollout, 
writer, -st + i/log_every_n) - - global_step = 0 - prev_global_step = 0 + val_logs(agent, cfg.validation, env, -st + i/log_every_n, writer) + + if cfg.training.checkpoint_path is not None: + prev_global_step = global_step = agent.load_ckpt(cfg.training.checkpoint_path) + else: + prev_global_step = global_step = 0 + pbar = tqdm(total=cfg.training.steps, desc='Training') while global_step < cfg.training.steps: ### Training and exploration @@ -117,7 +122,6 @@ def main(cfg: DictConfig): losses = agent.train(s, a, r, n, f, first) if cfg.debug.profiler: prof.step() - # NOTE: Do not forget to run test with every step to check for outliers if global_step % 10 == 0: for loss_name, loss in losses.items(): writer.add_scalar(f'train/{loss_name}', loss, global_step) @@ -127,22 +131,7 @@ def main(cfg: DictConfig): # FIXME: find more appealing solution ### Validation if (global_step % cfg.training.val_logs_every) < (prev_global_step % cfg.training.val_logs_every): - with torch.no_grad(): - rollouts = collect_rollout_num(env, cfg.validation.rollout_num, agent) - # TODO: make logs visualization in separate process - metrics = metrics_evaluator.calculate_metrics(rollouts) - for metric_name, metric in metrics.items(): - writer.add_scalar(f'val/{metric_name}', metric, global_step) - - if cfg.validation.visualize: - rollouts = collect_rollout_num(env, 1, agent, collect_obs=True) - - for rollout in rollouts: - video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) - writer.add_video('val/visualization', video, global_step) - # FIXME: Very bad from architecture point - with torch.no_grad(): - agent.viz_log(rollout, writer, global_step) + val_logs(agent, cfg.validation, env, global_step, writer) ### Checkpoint if (global_step % cfg.training.save_checkpoint_every) < (prev_global_step % cfg.training.save_checkpoint_every): From 095801552caeca2517177d4ac82534d7c9e211ab Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 18 Feb 2023 12:03:24 +0000 Subject: [PATCH 041/106] Added VQ-VAE as a step after each RNN step --- rl_sandbox/agents/dreamer_v2.py | 101 ++++++++++++++++++++++-- rl_sandbox/config/agent/dreamer_v2.yaml | 2 +- 2 files changed, 94 insertions(+), 9 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index a7a2013..2b9f4ec 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -8,7 +8,7 @@ import torch import torch.distributions as td from torch import nn -from torch.nn import functional as F +from torch.nn import ELU, functional as F from jaxtyping import Float, Bool from rl_sandbox.agents.rl_agent import RlAgent @@ -157,6 +157,85 @@ def forward(self, x, h): output = update * cand + (1 - update) * state return output, output + +class Quantize(nn.Module): + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): + super().__init__() + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + + embed = torch.randn(dim, n_embed) + self.register_buffer("embed", embed) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", embed.clone()) + + def forward(self, input): + input = input.reshape(-1, 1, self.n_embed, self.dim) + flatten = input.reshape(-1, self.dim) + dist = ( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ self.embed + + self.embed.pow(2).sum(0, keepdim=True) + ) + _, embed_ind = (-dist).max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = 
self.embed_code(embed_ind) + + if self.training: + embed_onehot_sum = embed_onehot.sum(0) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + + # dist_fn.all_reduce(embed_onehot_sum) + # dist_fn.all_reduce(embed_sum) + + self.cluster_size.data.mul_(self.decay).add_( + embed_onehot_sum, alpha=1 - self.decay + ) + self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) + n = self.cluster_size.sum() + cluster_size = ( + (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + self.embed.data.copy_(embed_normalized) + + diff = (quantize.detach() - input).pow(2).mean() + quantize = input + (quantize - input).detach() + + return quantize, diff, embed_ind + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.transpose(0, 1)) + + +# class MlpVAE(nn.Module): +# def __init__(self, input_size, latent_size=(8, 8)): +# self.encoder = fc_nn_generator(input_size, +# np.prod(latent_size), +# hidden_size=np.prod(latent_size), +# num_layers=2, +# intermediate_activation=nn.ELU, +# final_activation=nn.Sequential( +# View((-1,) + latent_size), +# DistLayer('onehot') +# )) +# self.decoder = fc_nn_generator(np.prod(latent_size), +# input_size, +# hidden_size=np.prod(latent_size), +# num_layers=2, +# intermediate_activation=nn.ELU, +# final_activation=nn.Identity()) + +# def forward(self, x): +# z = self.encoder(x) +# x_hat = self.decoder(z.rsample()) +# return z, x_hat + + class RSSM(nn.Module): """ Recurrent State Space Model @@ -218,6 +297,8 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'obs_dist' View((1, -1, latent_dim, self.latent_classes))) + # self.determ_discretizer = MlpVAE(self.hidden_size) + self.determ_discretizer = Quantize(16, 16) def estimate_stochastic_latent(self, prev_determ: torch.Tensor): dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] @@ -232,11 +313,13 @@ def predict_next(self, action) -> State: x = self.pre_determ_recurrent(torch.concat([prev_state.stoch, action], dim=-1)) # NOTE: x and determ are actually the same value if sequence of 1 is inserted - x, determ = self.determ_recurrent(x, prev_state.determ) + x, determ_prior = self.determ_recurrent(x, prev_state.determ) + determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) + determ_post = determ_post.reshape(determ_prior.shape) # used for KL divergence predicted_stoch_logits = self.estimate_stochastic_latent(x) - return State(determ, predicted_stoch_logits) + return State(determ_post, predicted_stoch_logits), diff def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' return State(prior.determ, self.stoch_net(torch.concat([prior.determ, embed], dim=-1))) @@ -249,10 +332,10 @@ def forward(self, h_prev: State, embed, 'a' <- action taken on prev step Returns 'h_next' <- the next next of the world """ - prior = self.predict_next(h_prev, action) + prior, diff = self.predict_next(h_prev, action) posterior = self.update_current(prior, embed) - return prior, posterior + return prior, posterior, diff class Encoder(nn.Module): @@ -365,7 +448,7 @@ def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): torch.zeros(seq_size, batch_size, self.latent_classes * self.latent_dim, device=device)) def predict_next(self, prev_state: State, action): - prior = self.recurrent_model.predict_next(prev_state, action) + prior, _ = 
self.recurrent_model.predict_next(prev_state, action) reward = self.reward_predictor(prior.combined).mode discount_factors = self.discount_predictor(prior.combined).sample() @@ -375,7 +458,7 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State]) -> Sta if state is None: state = self.get_initial_state() embed = self.encoder(obs.unsqueeze(0)) - _, posterior = self.recurrent_model.forward(state, embed.unsqueeze(0), + _, posterior, _ = self.recurrent_model.forward(state, embed.unsqueeze(0), action) return posterior @@ -415,12 +498,14 @@ def KL(dist1, dist2, free_nat = True): embed_t, a_t, first_t = embed_c[:, t].unsqueeze(0), a_c[:, t].unsqueeze(0), first_c[:, t].unsqueeze(0) a_t = a_t * (1 - first_t) - prior, posterior = self.recurrent_model.forward(prev_state, embed_t, a_t) + prior, posterior, diff = self.recurrent_model.forward(prev_state, embed_t, a_t) prev_state = posterior priors.append(prior) posteriors.append(posterior) + losses['loss_determ_recons'] += diff + posterior = State.stack(posteriors) prior = State.stack(priors) diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index 801e643..fa8243f 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -3,7 +3,7 @@ _target_: rl_sandbox.agents.DreamerV2 batch_cluster_size: 50 latent_dim: 32 latent_classes: 32 -rssm_dim: 200 +rssm_dim: 256 kl_loss_scale: 1.0 kl_loss_balancing: 0.8 kl_loss_free_nats: 1.0 From 7ea44bfdd661682f7a9ea093879fe90b8215e746 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 18 Feb 2023 14:22:58 +0000 Subject: [PATCH 042/106] Added parameter for disabling discrete rssm --- rl_sandbox/agents/dreamer_v2.py | 22 ++++++++++++++++------ rl_sandbox/config/agent/dreamer_v2.yaml | 2 ++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 2b9f4ec..e5fc605 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -259,12 +259,13 @@ class RSSM(nn.Module): """ - def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): + def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discrete_rssm): super().__init__() self.latent_dim = latent_dim self.latent_classes = latent_classes self.ensemble_num = 1 self.hidden_size = hidden_size + self.discrete_rssm = discrete_rssm # Calculate deterministic state from prev stochastic, prev action and prev deterministic self.pre_determ_recurrent = nn.Sequential( @@ -291,6 +292,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): # FIXME: very bad magic number img_sz = 4 * 384 # 384*2x2 self.stoch_net = nn.Sequential( + # nn.LayerNorm(hidden_size + img_sz, hidden_size), nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' # nn.LayerNorm(hidden_size), nn.ELU(inplace=True), @@ -299,6 +301,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes): View((1, -1, latent_dim, self.latent_classes))) # self.determ_discretizer = MlpVAE(self.hidden_size) self.determ_discretizer = Quantize(16, 16) + self.determ_layer_norm = nn.LayerNorm(hidden_size) def estimate_stochastic_latent(self, prev_determ: torch.Tensor): dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] @@ -314,8 +317,12 @@ def predict_next(self, x = self.pre_determ_recurrent(torch.concat([prev_state.stoch, action], dim=-1)) # NOTE: x and determ are actually the same value if sequence of 1 is inserted x, 
determ_prior = self.determ_recurrent(x, prev_state.determ) - determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) - determ_post = determ_post.reshape(determ_prior.shape) + if self.discrete_rssm: + determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) + determ_post = determ_post.reshape(determ_prior.shape) + determ_post = self.determ_layer_norm(determ_post) + else: + determ_post, diff = determ_prior, 0 # used for KL divergence predicted_stoch_logits = self.estimate_stochastic_latent(x) @@ -409,7 +416,7 @@ def update(self, x): class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, - actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats): + actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm): super().__init__() self.kl_free_nats = kl_free_nats self.kl_beta = kl_loss_scale @@ -424,7 +431,8 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.recurrent_model = RSSM(latent_dim, rssm_dim, actions_num, - latent_classes=latent_classes) + latent_classes, + discrete_rssm) self.encoder = Encoder() self.image_predictor = Decoder(rssm_dim + latent_dim * latent_classes) self.reward_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, @@ -610,6 +618,7 @@ def __init__( world_model_lr: float, actor_lr: float, critic_lr: float, + discrete_rssm: bool, device_type: str = 'cpu', logger = None): @@ -625,7 +634,8 @@ def __init__( self.world_model = WorldModel(batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, - kl_loss_balancing, kl_loss_free_nats).to(device_type) + kl_loss_balancing, kl_loss_free_nats, + discrete_rssm).to(device_type) self.actor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, actions_num * 2 if action_type == 'continuous' else actions_num, diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index fa8243f..f5c6fc0 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -25,3 +25,5 @@ critic_value_target_lambda: 0.95 critic_update_interval: 100 # [0-1], 1 means hard update critic_soft_update_fraction: 1 + +discrete_rssm: true From d534717a09a9863a94de1723f9c8c4807a1e774b Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 18 Feb 2023 14:25:05 +0000 Subject: [PATCH 043/106] Added Crafter environment --- pyproject.toml | 3 +- rl_sandbox/agents/dreamer_v2.py | 38 ++++++++++++++----------- rl_sandbox/config/agent/dreamer_v2.yaml | 2 +- rl_sandbox/config/env/crafter.yaml | 6 ++++ rl_sandbox/train.py | 11 ++++--- rl_sandbox/utils/env.py | 33 +++++++++++++++++---- rl_sandbox/utils/rollout_generation.py | 9 ++++-- 7 files changed, 69 insertions(+), 33 deletions(-) create mode 100644 rl_sandbox/config/env/crafter.yaml diff --git a/pyproject.toml b/pyproject.toml index 094fee9..9644b40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ packages = [{include = 'rl_sandbox'}] python = "^3.10" numpy = '*' nptyping = '*' -gym = "^0.26.1" +gym = "0.25.0" # crafter requires old step api pygame = '*' moviepy = '*' torchvision = '^0.13' @@ -28,6 +28,7 @@ webdataset = "^0.2.20" jaxtyping = '^0.2.0' lovely_tensors = '^0.1.10' torchshow = '^0.5.0' +crafter = '^1.8.0' [tool.yapf] based_on_style = "pep8" diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index e5fc605..8c71f99 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -67,6 +67,7 @@ def forward(self, 
x): class DistLayer(nn.Module): def __init__(self, type: str): super().__init__() + self._dist = type match type: case 'mse': self.dist = lambda x: td.Normal(x.float(), 1.0) @@ -93,7 +94,11 @@ def get_trunc_normal(x, min_std=0.1): raise RuntimeError("Invalid dist layer") def forward(self, x): - return td.Independent(self.dist(x), 1) + match self._dist: + case 'onehot': + return self.dist(x) + case _: + return td.Independent(self.dist(x), 1) def Dist(val): return DistLayer('onehot')(val) @@ -628,9 +633,10 @@ def __init__( self.cluster_size = batch_cluster_size self.actions_num = actions_num self.rho = actor_reinforce_fraction - # if actor_reinforce_fraction != 0: - # raise NotImplementedError("Reinforce part is not implemented") self.eta = actor_entropy_scale + self.is_discrete = (action_type != 'continuous') + if self.rho is None: + self.rho = self.is_discrete self.world_model = WorldModel(batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, @@ -638,11 +644,11 @@ def __init__( discrete_rssm).to(device_type) self.actor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, - actions_num * 2 if action_type == 'continuous' else actions_num, + actions_num if self.is_discrete else actions_num * 2, 400, 5, intermediate_activation=nn.ELU, - final_activation=DistLayer('normal_trunc' if action_type == 'continuous' else 'onehot')).to(device_type) + final_activation=DistLayer('onehot' if self.is_discrete else 'normal_trunc')).to(device_type) self.critic = ImaginativeCritic(discount_factor, critic_update_interval, critic_soft_update_fraction, @@ -720,13 +726,13 @@ def get_action(self, obs: Observation) -> Action: actor_dist = self.actor(self._state.combined) self._last_action = actor_dist.sample() - if False: - self._action_probs += actor_dist.base_dist.probs.squeeze() - self._latent_probs += self._state.stoch_dist.base_dist.probs.squeeze() + if self.is_discrete: + self._action_probs += actor_dist.probs.squeeze() + self._latent_probs += self._state.stoch_dist.probs.squeeze() self._stored_steps += 1 - if False: - return np.array([self._last_action.squeeze().detach().cpu().numpy().argmax()]) + if self.is_discrete: + return self._last_action.squeeze().detach().cpu().numpy().argmax() else: return self._last_action.squeeze().detach().cpu().numpy() @@ -734,6 +740,8 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ obs = torch.from_numpy(obs.copy()).to(self.device) obs = self.preprocess_obs(obs) actions = self.from_np(actions) + if self.is_discrete: + actions = F.one_hot(actions.to(torch.int64), num_classes=self.actions_num).squeeze() video = [] rews = [] @@ -780,7 +788,7 @@ def viz_log(self, rollout, logger, epoch_num): latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) # if discrete action space - if False: + if self.is_discrete: action_hist = (self._action_probs / self._stored_steps).detach().cpu().numpy() fig = plt.Figure() ax = fig.add_axes([0, 0, 1, 1]) @@ -807,17 +815,13 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation obs = self.preprocess_obs(self.from_np(obs)) a = self.from_np(a) - if False: + if self.is_discrete: a = F.one_hot(a.to(torch.int64), num_classes=self.actions_num).squeeze() r = self.from_np(r) next_obs = self.preprocess_obs(self.from_np(next_obs)) discount_factors = (1 - self.from_np(is_finished).type(torch.float32)) first_flags = self.from_np(is_first).type(torch.float32) - number_of_zero_discounts = (1 - discount_factors).sum() - if number_of_zero_discounts 
> 0: - pass - # take some latent embeddings as initial with torch.cuda.amp.autocast(enabled=True): losses, discovered_states, wm_metrics = self.world_model.calculate_loss( @@ -880,7 +884,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation action_dists = self.actor(zs[:-2].detach()) baseline = self.critic.target_critic(zs[:-2]).mode advantage = (vs[1:] - baseline).detach() - losses_ac['loss_actor_reinforce'] += 0# -(self.rho * action_dists.base_dist.log_prob(actions[1:-1].detach()).unsqueeze(2) * discount_factors[:-2] * advantage).mean() + losses_ac['loss_actor_reinforce'] += -(self.rho * action_dists.log_prob(actions[1:-1].detach()).unsqueeze(2) * discount_factors[:-2] * advantage).mean() losses_ac['loss_actor_dynamics_backprop'] = -((1 - self.rho) * (vs[1:]*discount_factors[:-2])).mean() def calculate_entropy(dist): diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index f5c6fc0..dd23671 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -16,7 +16,7 @@ imagination_horizon: 15 actor_lr: 8e-5 # mixing of reinforce and maximizing value func # for dm_control it is zero in Dreamer (Atari 1) -actor_reinforce_fraction: 0.0 +actor_reinforce_fraction: null actor_entropy_scale: 1e-4 critic_lr: 8e-5 diff --git a/rl_sandbox/config/env/crafter.yaml b/rl_sandbox/config/env/crafter.yaml new file mode 100644 index 0000000..822ac27 --- /dev/null +++ b/rl_sandbox/config/env/crafter.yaml @@ -0,0 +1,6 @@ +_target_: rl_sandbox.utils.env.GymEnv +task_name: CrafterReward-v1 +run_on_pixels: false # it is run on pixels by default +obs_res: [64, 64] +repeat_action_num: 1 +transforms: [] diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index e2e62ad..820964c 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -8,6 +8,7 @@ import torch from torch.profiler import profile, ProfilerActivity import lovely_tensors as lt +from gym.spaces import Discrete from rl_sandbox.metrics import MetricsEvaluator from rl_sandbox.utils.env import Env @@ -70,13 +71,11 @@ def main(cfg: DictConfig): # (Plan2Explore, etc) writer = SummaryWriter(comment=cfg.log_message or "") + is_discrete = isinstance(env.action_space, Discrete) agent = hydra.utils.instantiate(cfg.agent, - obs_space_num=env.observation_space.shape[0], - # FIXME: feels bad - # actions_num=(env.action_space.high - env.action_space.low + 1).item(), - # FIXME: currently only continuous tasks - actions_num=env.action_space.shape[0], - action_type='continuous', + obs_space_num=env.observation_space.shape, + actions_num = env.action_space.n if is_discrete else env.action_space.shape[0], + action_type='discrete' if is_discrete else 'continuous' , device_type=cfg.device_type, logger=writer) diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py index 0ac48ce..94cf92d 100644 --- a/rl_sandbox/utils/env.py +++ b/rl_sandbox/utils/env.py @@ -3,6 +3,7 @@ from dataclasses import dataclass import gym +import crafter import numpy as np from dm_control import suite from dm_env import Environment as dmEnviron @@ -136,21 +137,41 @@ def __init__(self, task_name: str, run_on_pixels: bool, obs_res: tuple[int, int] repeat_action_num: int, transforms: list[ActionTransformer]): super().__init__(run_on_pixels, obs_res, repeat_action_num, transforms) + self.task_name = task_name self.env: gym.Env = gym.make(task_name) - self.visualized_env: gym.Env = gym.make(task_name, render_mode='rgb_array_list') + if self.task_name.startswith("Crafter"): + 
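            # Importing crafter (added at the top of this file in this patch) is what
            # registers the CrafterReward-v1 id with gym; a later patch that drops the
            # module-level import re-imports it locally for the same reason. Note that
            # crafter.Recorder returns a wrapping environment rather than modifying
            # self.env in place, so the result of the call below is presumably meant to
            # be assigned back, as the later train.py change does with
            # val_env.env = crafter.Recorder(...).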
crafter.Recorder(self.env, + "runs/", + save_stats=True, + save_video=False, + save_episode=False) if run_on_pixels: raise NotImplementedError("Run on pixels supported only for 'dm_control'") - def _step(self, action: Action) -> EnvStepResult: - new_state, reward, terminated, _, _ = self.env.step(action) - return EnvStepResult(new_state, reward, terminated) + def render(self): + raise RuntimeError("Render is not supported for GymEnv") + + def _step(self, action: Action, repeat_num: int) -> EnvStepResult: + rew = 0 + for _ in range(repeat_num - 1): + new_state, reward, terminated, _ = self.env.step(action) + ts = EnvStepResult(new_state, reward, terminated) + if terminated: + break + rew += reward or 0.0 + if repeat_num == 1 or not terminated: + new_state, reward, terminated, _ = self.env.step(action) + env_res = EnvStepResult(new_state, reward, terminated) + else: + env_res = ts + env_res.reward = rew + (env_res.reward or 0.0) + return env_res def reset(self): - state, _ = self.env.reset() + state = self.env.reset() return EnvStepResult(state, 0, False) - @property def _observation_space(self): return self.env.observation_space diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index ded11d6..3642c44 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -1,5 +1,6 @@ import typing as t from multiprocessing.synchronize import Lock +from IPython.core.inputtransformer2 import warnings import numpy as np import torch.multiprocessing as mp @@ -62,8 +63,12 @@ def iter_rollout( new_state, reward, terminated = unpack(env.step(action)) - # FIXME: will break for non-DM - obs = env.render() if collect_obs else None + try: + obs = env.render() if collect_obs else None + except RuntimeError: + # FIXME: hot-fix for Crafter env to work + warnings.warn("Cannot render environment, using state instead") + obs = state # if collect_obs and isinstance(env, dmEnv): yield state, prev_action, prev_reward, new_state, prev_terminated, obs state = new_state From 74c1a3cc1f7daf83a0d4ccb00222b786cdae3536 Mon Sep 17 00:00:00 2001 From: Midren Date: Sun, 19 Feb 2023 14:26:16 +0000 Subject: [PATCH 044/106] Added logging of gradients --- rl_sandbox/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 820964c..ff2ae85 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -124,6 +124,9 @@ def main(cfg: DictConfig): if global_step % 10 == 0: for loss_name, loss in losses.items(): writer.add_scalar(f'train/{loss_name}', loss, global_step) + for tag, value in agent.world_model.named_parameters(): + tag = tag.replace('.', '/') + writer.add_histogram(f'train/grad/{tag}', value.grad.data.cpu().numpy(), global_step) global_step += cfg.env.repeat_action_num pbar.update(cfg.env.repeat_action_num) From 07ee10f4cfe801f932e440b3d3265648274918f6 Mon Sep 17 00:00:00 2001 From: Midren Date: Mon, 20 Feb 2023 23:29:45 +0000 Subject: [PATCH 045/106] Added missing Dreamer parts, required for Crafter - Added replay buffer end prioritization - Added discount prediction - Added separate config for Crafter, so changed parameters - Added gradient logging --- rl_sandbox/agents/dreamer_v2.py | 37 ++++++++++-------- rl_sandbox/config/agent/dreamer_v2.yaml | 3 +- .../config/agent/dreamer_v2_crafter.yaml | 29 ++++++++++++++ rl_sandbox/config/config.yaml | 18 ++++----- rl_sandbox/train.py | 39 ++++++++++++------- rl_sandbox/utils/env.py | 10 +---- rl_sandbox/utils/fc_nn.py | 2 + 
rl_sandbox/utils/replay_buffer.py | 15 +++++-- 8 files changed, 101 insertions(+), 52 deletions(-) create mode 100644 rl_sandbox/config/agent/dreamer_v2_crafter.yaml diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 8c71f99..9cd769e 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -276,7 +276,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret self.pre_determ_recurrent = nn.Sequential( nn.Linear(latent_dim * latent_classes + actions_num, hidden_size), # Dreamer 'img_in' - # nn.LayerNorm(hidden_size), + nn.LayerNorm(hidden_size), nn.ELU(inplace=True) ) self.determ_recurrent = GRUCell(input_size=hidden_size, hidden_size=hidden_size, norm=True) # Dreamer gru '_cell' @@ -286,7 +286,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret self.ensemble_prior_estimator = nn.ModuleList([ nn.Sequential( nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' - # nn.LayerNorm(hidden_size), + nn.LayerNorm(hidden_size), nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' @@ -299,13 +299,13 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret self.stoch_net = nn.Sequential( # nn.LayerNorm(hidden_size + img_sz, hidden_size), nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' - # nn.LayerNorm(hidden_size), + nn.LayerNorm(hidden_size), nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'obs_dist' View((1, -1, latent_dim, self.latent_classes))) # self.determ_discretizer = MlpVAE(self.hidden_size) - self.determ_discretizer = Quantize(16, 16) + self.determ_discretizer = Quantize(32, 32) self.determ_layer_norm = nn.LayerNorm(hidden_size) def estimate_stochastic_latent(self, prev_determ: torch.Tensor): @@ -361,7 +361,7 @@ def __init__(self, kernel_sizes=[4, 4, 4, 4]): for i, k in enumerate(kernel_sizes): out_channels = 2**i * channel_step layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2)) - # layers.append(nn.BatchNorm2d(out_channels)) + layers.append(nn.GroupNorm(1, out_channels)) layers.append(nn.ELU(inplace=True)) in_channels = out_channels layers.append(nn.Flatten()) @@ -387,7 +387,7 @@ def __init__(self, input_size, kernel_sizes=[5, 5, 6, 6]): out_channels = 3 layers.append(nn.ConvTranspose2d(in_channels, 3, kernel_size=k, stride=2)) else: - # layers.append(nn.BatchNorm2d(in_channels)) + layers.append(nn.GroupNorm(1, in_channels)) layers.append( nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=2)) @@ -421,7 +421,8 @@ def update(self, x): class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, - actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm): + actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, + predict_discount): super().__init__() self.kl_free_nats = kl_free_nats self.kl_beta = kl_loss_scale @@ -432,6 +433,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.actions_num = actions_num # kl loss balancing (prior/posterior) self.alpha = kl_loss_balancing + self.predict_discount = predict_discount self.recurrent_model = RSSM(latent_dim, rssm_dim, @@ -464,7 +466,10 @@ def predict_next(self, prev_state: State, action): prior, _ = self.recurrent_model.predict_next(prev_state, action) reward = self.reward_predictor(prior.combined).mode - discount_factors = 
self.discount_predictor(prior.combined).sample() + if self.predict_discount: + discount_factors = self.discount_predictor(prior.combined).sample() + else: + discount_factors = torch.ones_like(reward) return prior, reward, discount_factors def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State]) -> State: @@ -621,6 +626,7 @@ def __init__( critic_soft_update_fraction: float, critic_value_target_lambda: float, world_model_lr: float, + world_model_predict_discount: bool, actor_lr: float, critic_lr: float, discrete_rssm: bool, @@ -641,7 +647,8 @@ def __init__( self.world_model = WorldModel(batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_loss_free_nats, - discrete_rssm).to(device_type) + discrete_rssm, + world_model_predict_discount).to(device_type) self.actor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, actions_num if self.is_discrete else actions_num * 2, @@ -841,6 +848,8 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # FIXME: clip gradient should be parametrized self.scaler.unscale_(self.world_model_optimizer) + for tag, value in self.world_model.named_parameters(): + wm_metrics[f"grad/{tag.replace('.', '/')}"] = value.detach() nn.utils.clip_grad_norm_(self.world_model.parameters(), 100) self.scaler.step(self.world_model_optimizer) @@ -859,10 +868,6 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation zs = states.combined rewards = self.world_model.reward_normalizer(rewards) - # Discount prediction is disabled for dmc vision in Dreamer - # as trajectory will not abruptly stop - discount_factors = self.critic.gamma * torch.ones_like(rewards) - # Discounted factors should be shifted as they predict whether next state cannot be used # First discount factor on contrary is always 1 as it cannot lead to trajectory finish discount_factors = torch.cat([torch.ones_like(discount_factors[:1]), discount_factors[:-1]], dim=0).detach() @@ -921,9 +926,9 @@ def calculate_entropy(dist): self.critic.update_target() self.scaler.update() - losses = {l: val.detach().cpu().item() for l, val in losses.items()} - losses_ac = {l: val.detach().cpu().item() for l, val in losses_ac.items()} - metrics = {l: val.detach().cpu().item() for l, val in metrics.items()} + losses = {l: val.detach().cpu().numpy() for l, val in losses.items()} + losses_ac = {l: val.detach().cpu().numpy() for l, val in losses_ac.items()} + metrics = {l: val.detach().cpu().numpy() for l, val in metrics.items()} return losses | losses_ac | metrics diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index dd23671..e9efc94 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -8,6 +8,7 @@ kl_loss_scale: 1.0 kl_loss_balancing: 0.8 kl_loss_free_nats: 1.0 world_model_lr: 3e-4 +world_model_predict_discount: false # ActorCritic parameters discount_factor: 0.999 @@ -26,4 +27,4 @@ critic_update_interval: 100 # [0-1], 1 means hard update critic_soft_update_fraction: 1 -discrete_rssm: true +discrete_rssm: false diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml new file mode 100644 index 0000000..0cf23d3 --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -0,0 +1,29 @@ +_target_: rl_sandbox.agents.DreamerV2 +# World model parameters +batch_cluster_size: 50 +latent_dim: 32 +latent_classes: 32 +rssm_dim: 1024 +kl_loss_scale: 1.0 
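# In the standard DreamerV2 formulation these kl_* settings control the latent
# regulariser  kl_loss_scale * (a * KL(sg(post) || prior) + (1 - a) * KL(post || sg(prior))),
# where a is kl_loss_balancing (0.8 puts most of the gradient on pulling the
# prior towards the posterior) and kl_loss_free_nats clips the KL from below.
# Free nats are 0.0 here, unlike the 1.0 used in the base dreamer_v2.yaml.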
+kl_loss_balancing: 0.8 +kl_loss_free_nats: 0.0 +world_model_lr: 1e-4 +world_model_predict_discount: true + +# ActorCritic parameters +discount_factor: 0.999 +imagination_horizon: 15 + +actor_lr: 1e-4 +# automatically chooses depending on discrete/continuous env +actor_reinforce_fraction: null +actor_entropy_scale: 3e-3 + +critic_lr: 1e-4 +# Lambda parameter for trainin deeper multi-step prediction +critic_value_target_lambda: 0.95 +critic_update_interval: 100 +# [0-1], 1 means hard update +critic_soft_update_fraction: 1 + +discrete_rssm: true diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 5111091..7ff91a9 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,24 +1,22 @@ defaults: - - agent/dreamer_v2 - #- env/dm_walker - - env/dm_cartpole - #- env/dm_quadruped - #- env/dm_cheetah + - agent: dreamer_v2_crafter + - env: crafter - _self_ seed: 42 device_type: cuda -log_message: +log_message: Activated Discrete RSSM training: checkpoint_path: null steps: 1e6 - prefill: 1000 - pretrain: 100 + prefill: 10000 + pretrain: 1 batch_size: 800 + prioritize_ends: true gradient_steps_per_step: 5 - save_checkpoint_every: 1e5 - val_logs_every: 2.5e3 + save_checkpoint_every: 3e5 + val_logs_every: 5e3 validation: rollout_num: 5 diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index ff2ae85..7da1edc 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -9,6 +9,7 @@ from torch.profiler import profile, ProfilerActivity import lovely_tensors as lt from gym.spaces import Discrete +import crafter from rl_sandbox.metrics import MetricsEvaluator from rl_sandbox.utils.env import Env @@ -62,15 +63,24 @@ def main(cfg: DictConfig): torch.manual_seed(cfg.seed) np.random.seed(cfg.seed) - env: Env = hydra.utils.instantiate(cfg.env) - - buff = ReplayBuffer() - fillup_replay_buffer(env, buff, max(cfg.training.prefill, cfg.training.batch_size)) - # TODO: Implement smarter techniques for exploration # (Plan2Explore, etc) writer = SummaryWriter(comment=cfg.log_message or "") + env: Env = hydra.utils.instantiate(cfg.env) + val_env: Env = hydra.utils.instantiate(cfg.env) + # TOOD: Create maybe some additional validation env + if cfg.env.task_name.startswith("Crafter"): + val_env.env = crafter.Recorder(val_env.env, + writer.log_dir, + save_stats=True, + save_video=False, + save_episode=False) + + buff = ReplayBuffer(prioritize_ends=cfg.training.prioritize_ends, + min_ep_len=cfg.agent.get('batch_cluster_size', 1)*(cfg.training.prioritize_ends + 1)) + fillup_replay_buffer(env, buff, max(cfg.training.prefill, cfg.training.batch_size)) + is_discrete = isinstance(env.action_space, Discrete) agent = hydra.utils.instantiate(cfg.agent, obs_space_num=env.observation_space.shape, @@ -91,13 +101,16 @@ def main(cfg: DictConfig): cluster_size=cfg.agent.get('batch_cluster_size', 1)) losses = agent.train(s, a, r, n, f, first) for loss_name, loss in losses.items(): - writer.add_scalar(f'pre_train/{loss_name}', loss, i) + if 'grad' in loss_name: + writer.add_histogram(f'pre_train/{loss_name}', loss, i) + else: + writer.add_scalar(f'pre_train/{loss_name}', loss.item(), i) # TODO: remove constants log_every_n = 25 st = int(cfg.training.pretrain) // log_every_n if i % log_every_n == 0: - val_logs(agent, cfg.validation, env, -st + i/log_every_n, writer) + val_logs(agent, cfg.validation, val_env, -st + i/log_every_n, writer) if cfg.training.checkpoint_path is not None: prev_global_step = global_step = agent.load_ckpt(cfg.training.checkpoint_path) @@ -121,19 +134,19 @@ def 
main(cfg: DictConfig): losses = agent.train(s, a, r, n, f, first) if cfg.debug.profiler: prof.step() - if global_step % 10 == 0: + if global_step % 100 == 0: for loss_name, loss in losses.items(): - writer.add_scalar(f'train/{loss_name}', loss, global_step) - for tag, value in agent.world_model.named_parameters(): - tag = tag.replace('.', '/') - writer.add_histogram(f'train/grad/{tag}', value.grad.data.cpu().numpy(), global_step) + if 'grad' in loss_name: + writer.add_histogram(f'train/{loss_name}', loss, global_step) + else: + writer.add_scalar(f'train/{loss_name}', loss.item(), global_step) global_step += cfg.env.repeat_action_num pbar.update(cfg.env.repeat_action_num) # FIXME: find more appealing solution ### Validation if (global_step % cfg.training.val_logs_every) < (prev_global_step % cfg.training.val_logs_every): - val_logs(agent, cfg.validation, env, global_step, writer) + val_logs(agent, cfg.validation, val_env, global_step, writer) ### Checkpoint if (global_step % cfg.training.save_checkpoint_every) < (prev_global_step % cfg.training.save_checkpoint_every): diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py index 94cf92d..dbabd66 100644 --- a/rl_sandbox/utils/env.py +++ b/rl_sandbox/utils/env.py @@ -3,7 +3,6 @@ from dataclasses import dataclass import gym -import crafter import numpy as np from dm_control import suite from dm_env import Environment as dmEnviron @@ -139,12 +138,6 @@ def __init__(self, task_name: str, run_on_pixels: bool, obs_res: tuple[int, int] self.task_name = task_name self.env: gym.Env = gym.make(task_name) - if self.task_name.startswith("Crafter"): - crafter.Recorder(self.env, - "runs/", - save_stats=True, - save_video=False, - save_episode=False) if run_on_pixels: raise NotImplementedError("Run on pixels supported only for 'dm_control'") @@ -165,7 +158,8 @@ def _step(self, action: Action, repeat_num: int) -> EnvStepResult: env_res = EnvStepResult(new_state, reward, terminated) else: env_res = ts - env_res.reward = rew + (env_res.reward or 0.0) + # FIXME: move to config the option + env_res.reward = np.tanh(rew + (env_res.reward or 0.0)) return env_res def reset(self): diff --git a/rl_sandbox/utils/fc_nn.py b/rl_sandbox/utils/fc_nn.py index 6f0fc6c..6200ddf 100644 --- a/rl_sandbox/utils/fc_nn.py +++ b/rl_sandbox/utils/fc_nn.py @@ -10,9 +10,11 @@ def fc_nn_generator(input_num: int, assert num_layers >= 3 layers = [] layers.append(nn.Linear(input_num, hidden_size)) + layers.append(nn.LayerNorm(hidden_size)) layers.append(intermediate_activation(inplace=True)) for _ in range(num_layers - 2): layers.append(nn.Linear(hidden_size, hidden_size)) + layers.append(nn.LayerNorm(hidden_size)) layers.append(intermediate_activation(inplace=True)) layers.append(nn.Linear(hidden_size, output_num)) layers.append(final_activation) diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index a483898..6a2716c 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -32,10 +32,12 @@ def __len__(self): # TODO: make buffer concurrent-friendly class ReplayBuffer: - def __init__(self, max_len=2e6): + def __init__(self, max_len=2e6, prioritize_ends: bool = False, min_ep_len: int = 1): self.rollouts: deque[Rollout] = deque() self.rollouts_len: deque[int] = deque() self.curr_rollout = None + self.min_ep_len = min_ep_len + self.prioritize_ends = prioritize_ends self.max_len = max_len self.total_num = 0 @@ -43,6 +45,8 @@ def __len__(self): return self.total_num def add_rollout(self, rollout: Rollout): + if 
len(rollout.next_states) <= self.min_ep_len: + return # NOTE: only last next state is stored, all others are induced # from state on next step rollout.next_states = np.expand_dims(rollout.next_states[-1], 0) @@ -86,7 +90,7 @@ def sample( seq_num = batch_size // cluster_size # NOTE: constant creation of numpy arrays from self.rollout_len seems terrible for me s, a, r, n, t, is_first = [], [], [], [], [], [] - do_add_curr = self.curr_rollout is not None and len(self.curr_rollout.states) > cluster_size + do_add_curr = self.curr_rollout is not None and len(self.curr_rollout.states) > (cluster_size * (self.prioritize_ends + 1)) tot = self.total_num + (len(self.curr_rollout.states) if do_add_curr else 0) r_indeces = np.random.choice(len(self.rollouts) + int(do_add_curr), seq_num, @@ -99,9 +103,12 @@ def sample( # -1 because we don't have next_state on terminal rollout, r_len = self.curr_rollout, len(self.curr_rollout.states) - 1 - # NOTE: maybe just not add such small rollouts to buffer assert r_len > cluster_size - 1, "Rollout it too small" - s_idx = np.random.choice(r_len - cluster_size + 1, 1).item() + max_idx = r_len - cluster_size + 1 + if self.prioritize_ends: + s_idx = np.random.choice(max_idx - cluster_size + 1, 1).item() + cluster_size - 1 + else: + s_idx = np.random.choice(max_idx, 1).item() s_indeces.append(s_idx) if r_idx == len(self.rollouts): From b21467a26f643d64b309f372fc80da9ee5ea2e91 Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 21 Feb 2023 00:10:04 +0000 Subject: [PATCH 046/106] Added scheduler --- rl_sandbox/agents/dreamer_v2.py | 5 +++++ rl_sandbox/config/config.yaml | 2 +- rl_sandbox/utils/schedulers.py | 8 ++++++-- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 9cd769e..11ca292 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -17,6 +17,7 @@ Observations, Rewards, TerminationFlags, IsFirstFlags) from rl_sandbox.utils.dists import TruncatedNormal +from rl_sandbox.utils.schedulers import LinearScheduler class View(nn.Module): @@ -306,6 +307,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret View((1, -1, latent_dim, self.latent_classes))) # self.determ_discretizer = MlpVAE(self.hidden_size) self.determ_discretizer = Quantize(32, 32) + self.discretizer_scheduler = LinearScheduler(1.0, 0.0, 1_000_000) self.determ_layer_norm = nn.LayerNorm(hidden_size) def estimate_stochastic_latent(self, prev_determ: torch.Tensor): @@ -326,6 +328,8 @@ def predict_next(self, determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) determ_post = determ_post.reshape(determ_prior.shape) determ_post = self.determ_layer_norm(determ_post) + alpha = self.discretizer_scheduler.val + determ_post = alpha * determ_prior + (1-alpha) * determ_post else: determ_post, diff = determ_prior, 0 @@ -833,6 +837,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation with torch.cuda.amp.autocast(enabled=True): losses, discovered_states, wm_metrics = self.world_model.calculate_loss( obs, a, r, discount_factors, first_flags) + self.world_model.recurrent_model.discretizer_scheduler.step() # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device world_model_loss = torch.Tensor(0).to(self.device) diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 7ff91a9..80cbbae 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -5,7 +5,7 @@ 
defaults: seed: 42 device_type: cuda -log_message: Activated Discrete RSSM +log_message: Added scheduler for discretizer training: checkpoint_path: null diff --git a/rl_sandbox/utils/schedulers.py b/rl_sandbox/utils/schedulers.py index d49adf2..a68876c 100644 --- a/rl_sandbox/utils/schedulers.py +++ b/rl_sandbox/utils/schedulers.py @@ -13,9 +13,13 @@ def __init__(self, initial_value, final_value, duration): self._dur = duration - 1 self._curr_t = 0 - def step(self) -> float: + @property + def val(self) -> float: if self._curr_t >= self._dur: return self._final - val = np.interp([self._curr_t], [0, self._dur], [self._init, self._final]) + return np.interp([self._curr_t], [0, self._dur], [self._init, self._final]).item() + + def step(self) -> float: + val = self.val self._curr_t += 1 return val From abb8a20f00659b2b4c43073223f390bfaa5f6cc6 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 4 Mar 2023 11:42:54 +0000 Subject: [PATCH 047/106] Fixed loss in quantizer --- rl_sandbox/agents/dreamer_v2.py | 22 ++++++++++++------- rl_sandbox/config/agent/dreamer_v2.yaml | 2 +- .../config/agent/dreamer_v2_crafter.yaml | 2 +- rl_sandbox/config/config.yaml | 12 +++++----- 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 11ca292..7be8823 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -174,12 +174,16 @@ def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): self.eps = eps embed = torch.randn(dim, n_embed) + self.inp_in = nn.Linear(1024, self.n_embed*self.dim) + self.inp_out = nn.Linear(self.n_embed*self.dim, 1024) self.register_buffer("embed", embed) self.register_buffer("cluster_size", torch.zeros(n_embed)) self.register_buffer("embed_avg", embed.clone()) - def forward(self, input): - input = input.reshape(-1, 1, self.n_embed, self.dim) + def forward(self, inp): + # input = self.inp_in(inp).reshape(-1, 1, self.n_embed, self.dim) + input = inp.reshape(-1, 1, self.n_embed, self.dim) + inp = input flatten = input.reshape(-1, self.dim) dist = ( flatten.pow(2).sum(1, keepdim=True) @@ -195,9 +199,6 @@ def forward(self, input): embed_onehot_sum = embed_onehot.sum(0) embed_sum = flatten.transpose(0, 1) @ embed_onehot - # dist_fn.all_reduce(embed_onehot_sum) - # dist_fn.all_reduce(embed_sum) - self.cluster_size.data.mul_(self.decay).add_( embed_onehot_sum, alpha=1 - self.decay ) @@ -209,8 +210,10 @@ def forward(self, input): embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) self.embed.data.copy_(embed_normalized) - diff = (quantize.detach() - input).pow(2).mean() - quantize = input + (quantize - input).detach() + # quantize_out = self.inp_out(quantize.reshape(-1, self.n_embed*self.dim)) + quantize_out = quantize + diff = 0.25*(quantize_out.detach() - inp).pow(2).mean() + (quantize_out - inp.detach()).pow(2).mean() + quantize = inp + (quantize_out - inp).detach() return quantize, diff, embed_ind @@ -306,7 +309,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret latent_dim * self.latent_classes), # Dreamer 'obs_dist' View((1, -1, latent_dim, self.latent_classes))) # self.determ_discretizer = MlpVAE(self.hidden_size) - self.determ_discretizer = Quantize(32, 32) + self.determ_discretizer = Quantize(16, 16) self.discretizer_scheduler = LinearScheduler(1.0, 0.0, 1_000_000) self.determ_layer_norm = nn.LayerNorm(hidden_size) @@ -925,6 +928,9 @@ def calculate_entropy(dist): nn.utils.clip_grad_norm_(self.actor.parameters(), 100) 
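        # A note on mixed-precision gradient clipping as used in train(): with
        # torch.cuda.amp, gradients stay multiplied by the loss scale until
        # GradScaler.unscale_(optimizer) is called, so a clip threshold such as the
        # hard-coded 100 here is only meaningful if the unscale happens before
        # nn.utils.clip_grad_norm_. The canonical pattern (the names below are generic
        # placeholders, not objects from this file) is roughly:
        #
        #     scaler.scale(loss).backward()
        #     scaler.unscale_(optimizer)                       # grads back to real scale
        #     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
        #     scaler.step(optimizer)                           # step is skipped on inf/nan
        #     scaler.update()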
nn.utils.clip_grad_norm_(self.critic.parameters(), 100) + for tag, value in self.actor.named_parameters(): + wm_metrics[f"grad/{tag.replace('.', '/')}"] = value.detach() + self.scaler.step(self.actor_optimizer) self.scaler.step(self.critic_optimizer) diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index e9efc94..4b7898c 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -27,4 +27,4 @@ critic_update_interval: 100 # [0-1], 1 means hard update critic_soft_update_fraction: 1 -discrete_rssm: false +discrete_rssm: true diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index 0cf23d3..9c246f6 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -26,4 +26,4 @@ critic_update_interval: 100 # [0-1], 1 means hard update critic_soft_update_fraction: 1 -discrete_rssm: true +discrete_rssm: false diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 80cbbae..c5798f2 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,21 +1,21 @@ defaults: - - agent: dreamer_v2_crafter - - env: crafter + - agent: dreamer_v2 + - env: dm_quadruped - _self_ seed: 42 device_type: cuda -log_message: Added scheduler for discretizer +log_message: Quadruped with rssm training: checkpoint_path: null steps: 1e6 prefill: 10000 - pretrain: 1 batch_size: 800 - prioritize_ends: true + pretrain: 100 + prioritize_ends: false gradient_steps_per_step: 5 - save_checkpoint_every: 3e5 + save_checkpoint_every: 1e6 val_logs_every: 5e3 validation: From 383ebfa536dbc222ac61a3d33c84ce3265749fb2 Mon Sep 17 00:00:00 2001 From: Midren Date: Thu, 9 Mar 2023 15:21:31 +0000 Subject: [PATCH 048/106] Fixes --- rl_sandbox/agents/dreamer_v2.py | 47 ++++++++++++------- rl_sandbox/config/agent/dreamer_v2.yaml | 5 +- .../config/agent/dreamer_v2_crafter.yaml | 1 + rl_sandbox/config/config.yaml | 14 ++++-- rl_sandbox/train.py | 9 ++-- rl_sandbox/utils/env.py | 2 + rl_sandbox/utils/fc_nn.py | 6 ++- 7 files changed, 53 insertions(+), 31 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 7be8823..ed47b3a 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -2,13 +2,14 @@ from collections import defaultdict from dataclasses import dataclass from pathlib import Path +from functools import partial import matplotlib.pyplot as plt import numpy as np import torch import torch.distributions as td from torch import nn -from torch.nn import ELU, functional as F +from torch.nn import functional as F from jaxtyping import Float, Bool from rl_sandbox.agents.rl_agent import RlAgent @@ -268,7 +269,7 @@ class RSSM(nn.Module): """ - def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discrete_rssm): + def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discrete_rssm, norm_layer: nn.LayerNorm | nn.Identity): super().__init__() self.latent_dim = latent_dim self.latent_classes = latent_classes @@ -280,7 +281,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret self.pre_determ_recurrent = nn.Sequential( nn.Linear(latent_dim * latent_classes + actions_num, hidden_size), # Dreamer 'img_in' - nn.LayerNorm(hidden_size), + norm_layer(hidden_size), nn.ELU(inplace=True) ) self.determ_recurrent = GRUCell(input_size=hidden_size, hidden_size=hidden_size, norm=True) 
# Dreamer gru '_cell' @@ -290,7 +291,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret self.ensemble_prior_estimator = nn.ModuleList([ nn.Sequential( nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' - nn.LayerNorm(hidden_size), + norm_layer(hidden_size), nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' @@ -303,7 +304,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret self.stoch_net = nn.Sequential( # nn.LayerNorm(hidden_size + img_sz, hidden_size), nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' - nn.LayerNorm(hidden_size), + norm_layer(hidden_size), nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'obs_dist' @@ -359,7 +360,7 @@ def forward(self, h_prev: State, embed, class Encoder(nn.Module): - def __init__(self, kernel_sizes=[4, 4, 4, 4]): + def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[4, 4, 4, 4]): super().__init__() layers = [] @@ -368,7 +369,7 @@ def __init__(self, kernel_sizes=[4, 4, 4, 4]): for i, k in enumerate(kernel_sizes): out_channels = 2**i * channel_step layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2)) - layers.append(nn.GroupNorm(1, out_channels)) + layers.append(norm_layer(1, out_channels)) layers.append(nn.ELU(inplace=True)) in_channels = out_channels layers.append(nn.Flatten()) @@ -380,7 +381,7 @@ def forward(self, X): class Decoder(nn.Module): - def __init__(self, input_size, kernel_sizes=[5, 5, 6, 6]): + def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 6, 6]): super().__init__() layers = [] self.channel_step = 48 @@ -394,7 +395,7 @@ def __init__(self, input_size, kernel_sizes=[5, 5, 6, 6]): out_channels = 3 layers.append(nn.ConvTranspose2d(in_channels, 3, kernel_size=k, stride=2)) else: - layers.append(nn.GroupNorm(1, in_channels)) + layers.append(norm_layer(1, in_channels)) layers.append( nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=2)) @@ -429,7 +430,7 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, - predict_discount): + predict_discount, layer_norm: bool): super().__init__() self.kl_free_nats = kl_free_nats self.kl_beta = kl_loss_scale @@ -446,20 +447,24 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, rssm_dim, actions_num, latent_classes, - discrete_rssm) - self.encoder = Encoder() - self.image_predictor = Decoder(rssm_dim + latent_dim * latent_classes) + discrete_rssm, + norm_layer=nn.Identity if layer_norm else nn.LayerNorm) + self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm) + self.image_predictor = Decoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.Identity if layer_norm else nn.GroupNorm) self.reward_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, hidden_size=400, num_layers=5, intermediate_activation=nn.ELU, + layer_norm=layer_norm, final_activation=DistLayer('mse')) self.discount_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, hidden_size=400, num_layers=5, intermediate_activation=nn.ELU, + layer_norm=layer_norm, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) @@ -558,7 +563,7 @@ def KL(dist1, dist2, free_nat = True): class 
ImaginativeCritic(nn.Module): def __init__(self, discount_factor: float, update_interval: int, - soft_update_fraction: float, value_target_lambda: float, latent_dim: int): + soft_update_fraction: float, value_target_lambda: float, latent_dim: int, layer_norm: bool): super().__init__() self.gamma = discount_factor self.critic_update_interval = update_interval @@ -571,12 +576,14 @@ def __init__(self, discount_factor: float, update_interval: int, 400, 5, intermediate_activation=nn.ELU, + layer_norm=layer_norm, final_activation=DistLayer('mse')) self.target_critic = fc_nn_generator(latent_dim, 1, 400, 5, intermediate_activation=nn.ELU, + layer_norm=layer_norm, final_activation=DistLayer('mse')) self.target_critic.requires_grad_(False) @@ -637,6 +644,7 @@ def __init__( actor_lr: float, critic_lr: float, discrete_rssm: bool, + layer_norm: bool, device_type: str = 'cpu', logger = None): @@ -655,19 +663,21 @@ def __init__( rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_loss_free_nats, discrete_rssm, - world_model_predict_discount).to(device_type) + world_model_predict_discount, layer_norm).to(device_type) self.actor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, actions_num if self.is_discrete else actions_num * 2, 400, 5, + layer_norm=layer_norm, intermediate_activation=nn.ELU, final_activation=DistLayer('onehot' if self.is_discrete else 'normal_trunc')).to(device_type) self.critic = ImaginativeCritic(discount_factor, critic_update_interval, critic_soft_update_fraction, critic_value_target_lambda, - rssm_dim + latent_dim * latent_classes).to(device_type) + rssm_dim + latent_dim * latent_classes, + layer_norm=layer_norm).to(device_type) self.scaler = torch.cuda.amp.GradScaler() self.world_model_optimizer = torch.optim.AdamW(self.world_model.parameters(), @@ -941,6 +951,7 @@ def calculate_entropy(dist): losses_ac = {l: val.detach().cpu().numpy() for l, val in losses_ac.items()} metrics = {l: val.detach().cpu().numpy() for l, val in metrics.items()} + losses['total'] = sum(losses.values()) return losses | losses_ac | metrics def save_ckpt(self, epoch_num: int, losses: dict[str, float]): @@ -948,13 +959,13 @@ def save_ckpt(self, epoch_num: int, losses: dict[str, float]): { 'epoch': epoch_num, 'world_model_state_dict': self.world_model.state_dict(), - 'world_model_optimizer_state_dict':self.world_model_optimizer.state_dict(), + 'world_model_optimizer_state_dict': self.world_model_optimizer.state_dict(), 'actor_state_dict': self.actor.state_dict(), 'critic_state_dict': self.critic.state_dict(), 'actor_optimizer_state_dict': self.actor_optimizer.state_dict(), 'critic_optimizer_state_dict': self.critic_optimizer.state_dict(), 'losses': losses - }, f'dreamerV2-{epoch_num}-{sum(losses.values())}.ckpt') + }, f'dreamerV2-{epoch_num}-{losses["total"]}.ckpt') def load_ckpt(self, ckpt_path: Path): ckpt = torch.load(ckpt_path) diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index 4b7898c..28cebfb 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -1,9 +1,10 @@ _target_: rl_sandbox.agents.DreamerV2 +layer_norm: true # World model parameters batch_cluster_size: 50 latent_dim: 32 latent_classes: 32 -rssm_dim: 256 +rssm_dim: 200 kl_loss_scale: 1.0 kl_loss_balancing: 0.8 kl_loss_free_nats: 1.0 @@ -27,4 +28,4 @@ critic_update_interval: 100 # [0-1], 1 means hard update critic_soft_update_fraction: 1 -discrete_rssm: true +discrete_rssm: false diff --git 
a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index 9c246f6..5dc802a 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -1,4 +1,5 @@ _target_: rl_sandbox.agents.DreamerV2 +layer_norm: true # World model parameters batch_cluster_size: 50 latent_dim: 32 diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index c5798f2..6d504fb 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,22 +1,26 @@ defaults: - agent: dreamer_v2 - - env: dm_quadruped + - env: dm_cartpole - _self_ seed: 42 device_type: cuda -log_message: Quadruped with rssm + +logger: + message: Cartpole with discrete + log_grads: false training: checkpoint_path: null - steps: 1e6 + steps: 5e5 prefill: 10000 batch_size: 800 pretrain: 100 prioritize_ends: false gradient_steps_per_step: 5 - save_checkpoint_every: 1e6 - val_logs_every: 5e3 + save_checkpoint_every: 2e5 + val_logs_every: 2.5e4 + validation: rollout_num: 5 diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 7da1edc..56b9b34 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -65,7 +65,7 @@ def main(cfg: DictConfig): # TODO: Implement smarter techniques for exploration # (Plan2Explore, etc) - writer = SummaryWriter(comment=cfg.log_message or "") + writer = SummaryWriter(comment=cfg.logger.message or "") env: Env = hydra.utils.instantiate(cfg.env) val_env: Env = hydra.utils.instantiate(cfg.env) @@ -102,7 +102,8 @@ def main(cfg: DictConfig): losses = agent.train(s, a, r, n, f, first) for loss_name, loss in losses.items(): if 'grad' in loss_name: - writer.add_histogram(f'pre_train/{loss_name}', loss, i) + if cfg.logger.log_grads: + writer.add_histogram(f'pre_train/{loss_name}', loss, i) else: writer.add_scalar(f'pre_train/{loss_name}', loss.item(), i) @@ -121,7 +122,6 @@ def main(cfg: DictConfig): while global_step < cfg.training.steps: ### Training and exploration - # TODO: add buffer end prioritarization for s, a, r, n, f, _ in iter_rollout(env, agent): buff.add_sample(s, a, r, n, f) @@ -137,7 +137,8 @@ def main(cfg: DictConfig): if global_step % 100 == 0: for loss_name, loss in losses.items(): if 'grad' in loss_name: - writer.add_histogram(f'train/{loss_name}', loss, global_step) + if cfg.logger.log_grads: + writer.add_histogram(f'train/{loss_name}', loss, global_step) else: writer.add_scalar(f'train/{loss_name}', loss.item(), global_step) global_step += cfg.env.repeat_action_num diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py index dbabd66..515dabb 100644 --- a/rl_sandbox/utils/env.py +++ b/rl_sandbox/utils/env.py @@ -137,6 +137,8 @@ def __init__(self, task_name: str, run_on_pixels: bool, obs_res: tuple[int, int] super().__init__(run_on_pixels, obs_res, repeat_action_num, transforms) self.task_name = task_name + if self.task_name.startswith('Crafter'): + import crafter self.env: gym.Env = gym.make(task_name) if run_on_pixels: diff --git a/rl_sandbox/utils/fc_nn.py b/rl_sandbox/utils/fc_nn.py index 6200ddf..c20f993 100644 --- a/rl_sandbox/utils/fc_nn.py +++ b/rl_sandbox/utils/fc_nn.py @@ -6,7 +6,9 @@ def fc_nn_generator(input_num: int, hidden_size: int, num_layers: int, intermediate_activation: t.Type[nn.Module] = nn.ReLU, - final_activation: nn.Module = nn.Identity()): + final_activation: nn.Module = nn.Identity(), + layer_norm: bool = False): + norm_layer = nn.Identity if layer_norm else nn.LayerNorm assert num_layers >= 3 layers = [] 
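    # The norm_layer selection above reads "nn.Identity if layer_norm else
    # nn.LayerNorm", i.e. enabling the layer_norm flag disables normalisation; the same
    # inverted-looking ternary appears for the norm_layer arguments in dreamer_v2.py in
    # this patch. If the flag is meant to switch LayerNorm on, the intended line is
    # presumably:
    #
    #     norm_layer = nn.LayerNorm if layer_norm else nn.Identity
    #
    # Note also that only the LayerNorm inside the hidden-layer loop is routed through
    # norm_layer below; the one appended after the input Linear stays unconditional.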
layers.append(nn.Linear(input_num, hidden_size)) @@ -14,7 +16,7 @@ def fc_nn_generator(input_num: int, layers.append(intermediate_activation(inplace=True)) for _ in range(num_layers - 2): layers.append(nn.Linear(hidden_size, hidden_size)) - layers.append(nn.LayerNorm(hidden_size)) + layers.append(norm_layer(hidden_size)) layers.append(intermediate_activation(inplace=True)) layers.append(nn.Linear(hidden_size, output_num)) layers.append(final_activation) From 6366540e87b6605aa407f73aa804a0779d9d8c8c Mon Sep 17 00:00:00 2001 From: Midren Date: Thu, 23 Mar 2023 10:28:06 +0000 Subject: [PATCH 049/106] Small fixes --- rl_sandbox/agents/dreamer_v2.py | 31 +------- rl_sandbox/config/config.yaml | 5 +- rl_sandbox/train.py | 119 ++++++++++++++---------------- rl_sandbox/utils/logger.py | 61 +++++++++++++++ rl_sandbox/utils/replay_buffer.py | 9 ++- 5 files changed, 127 insertions(+), 98 deletions(-) create mode 100644 rl_sandbox/utils/logger.py diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index ed47b3a..4498e9d 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -222,30 +222,6 @@ def embed_code(self, embed_id): return F.embedding(embed_id, self.embed.transpose(0, 1)) -# class MlpVAE(nn.Module): -# def __init__(self, input_size, latent_size=(8, 8)): -# self.encoder = fc_nn_generator(input_size, -# np.prod(latent_size), -# hidden_size=np.prod(latent_size), -# num_layers=2, -# intermediate_activation=nn.ELU, -# final_activation=nn.Sequential( -# View((-1,) + latent_size), -# DistLayer('onehot') -# )) -# self.decoder = fc_nn_generator(np.prod(latent_size), -# input_size, -# hidden_size=np.prod(latent_size), -# num_layers=2, -# intermediate_activation=nn.ELU, -# final_activation=nn.Identity()) - -# def forward(self, x): -# z = self.encoder(x) -# x_hat = self.decoder(z.rsample()) -# return z, x_hat - - class RSSM(nn.Module): """ Recurrent State Space Model @@ -821,9 +797,9 @@ def viz_log(self, rollout, logger, epoch_num): else: # log mean +- std pass - logger.add_image('val/latent_probs', latent_hist, epoch_num, dataformats='HW') - logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=1), epoch_num, dataformats='HW') - logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num, fps=20) + logger.add_image('val/latent_probs', latent_hist, epoch_num) + logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=1), epoch_num) + logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) @@ -842,7 +818,6 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation if self.is_discrete: a = F.one_hot(a.to(torch.int64), num_classes=self.actions_num).squeeze() r = self.from_np(r) - next_obs = self.preprocess_obs(self.from_np(next_obs)) discount_factors = (1 - self.from_np(is_finished).type(torch.float32)) first_flags = self.from_np(is_first).type(torch.float32) diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 6d504fb..b1121f6 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -7,6 +7,7 @@ seed: 42 device_type: cuda logger: + type: tensorboard message: Cartpole with discrete log_grads: false @@ -14,10 +15,10 @@ training: checkpoint_path: null steps: 5e5 prefill: 10000 - batch_size: 800 + batch_size: 50 pretrain: 100 
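  # With the sampling change to replay_buffer.py later in this patch, batch_size now
  # counts sequences of batch_cluster_size steps each rather than individual steps,
  # hence the drop from 800 to 50. train_every (renamed from gradient_steps_per_step)
  # is how many environment steps are taken between gradient updates, and
  # prioritize_ends is meant to bias sampled sequence start points towards the ends of
  # stored episodes.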
prioritize_ends: false - gradient_steps_per_step: 5 + train_every: 5 save_checkpoint_every: 2e5 val_logs_every: 2.5e4 diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 56b9b34..4c76780 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -1,52 +1,41 @@ -import hydra -import numpy as np -from omegaconf import DictConfig -from torch.utils.tensorboard.writer import SummaryWriter -from tqdm import tqdm import random -import torch -from torch.profiler import profile, ProfilerActivity +import crafter +import hydra import lovely_tensors as lt +import numpy as np +import torch from gym.spaces import Discrete -import crafter +from omegaconf import DictConfig +from torch.profiler import ProfilerActivity, profile +from tqdm import tqdm from rl_sandbox.metrics import MetricsEvaluator from rl_sandbox.utils.env import Env +from rl_sandbox.utils.logger import Logger from rl_sandbox.utils.replay_buffer import ReplayBuffer -from rl_sandbox.utils.rollout_generation import (collect_rollout_num, iter_rollout, - fillup_replay_buffer) - - -class SummaryWriterMock(): - def add_scalar(*args, **kwargs): - pass +from rl_sandbox.utils.rollout_generation import (collect_rollout_num, + fillup_replay_buffer, + iter_rollout) - def add_video(*args, **kwargs): - pass - def add_image(*args, **kwargs): - pass - - -def val_logs(agent, val_cfg: DictConfig, env, global_step, writer): +def val_logs(agent, val_cfg: DictConfig, env: Env, global_step: int, logger: Logger): with torch.no_grad(): rollouts = collect_rollout_num(env, val_cfg.rollout_num, agent) # TODO: make logs visualization in separate process # Possibly make the data loader metrics = MetricsEvaluator().calculate_metrics(rollouts) - for metric_name, metric in metrics.items(): - writer.add_scalar(f'val/{metric_name}', metric, global_step) + logger.log(metrics, global_step, mode='val') if val_cfg.visualize: rollouts = collect_rollout_num(env, 1, agent, collect_obs=True) for rollout in rollouts: video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) - writer.add_video('val/visualization', video, global_step, fps=20) + logger.add_video('val/visualization', video, global_step, fps=20) # FIXME: Very bad from architecture point with torch.no_grad(): - agent.viz_log(rollout, writer, global_step) + agent.viz_log(rollout, logger, global_step) @hydra.main(version_base="1.2", config_path='config', config_name='config') @@ -65,53 +54,55 @@ def main(cfg: DictConfig): # TODO: Implement smarter techniques for exploration # (Plan2Explore, etc) - writer = SummaryWriter(comment=cfg.logger.message or "") + logger = Logger(*cfg.logger) env: Env = hydra.utils.instantiate(cfg.env) val_env: Env = hydra.utils.instantiate(cfg.env) # TOOD: Create maybe some additional validation env if cfg.env.task_name.startswith("Crafter"): val_env.env = crafter.Recorder(val_env.env, - writer.log_dir, - save_stats=True, - save_video=False, - save_episode=False) + logger.log_dir(), + save_stats=True, + save_video=False, + save_episode=False) buff = ReplayBuffer(prioritize_ends=cfg.training.prioritize_ends, - min_ep_len=cfg.agent.get('batch_cluster_size', 1)*(cfg.training.prioritize_ends + 1)) - fillup_replay_buffer(env, buff, max(cfg.training.prefill, cfg.training.batch_size)) + min_ep_len=cfg.agent.get('batch_cluster_size', 1) * + (cfg.training.prioritize_ends + 1)) + fillup_replay_buffer( + env, buff, + max(cfg.training.prefill, + cfg.training.batch_size * cfg.agent.get('batch_cluster_size', 1))) is_discrete = isinstance(env.action_space, Discrete) - agent = 
hydra.utils.instantiate(cfg.agent, - obs_space_num=env.observation_space.shape, - actions_num = env.action_space.n if is_discrete else env.action_space.shape[0], - action_type='discrete' if is_discrete else 'continuous' , - device_type=cfg.device_type, - logger=writer) - - prof = profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], - on_trace_ready=torch.profiler.tensorboard_trace_handler('runs/profile_dreamer'), - schedule=torch.profiler.schedule(wait=10, warmup=10, active=5, repeat=5), - with_stack=True) if cfg.debug.profiler else None + agent = hydra.utils.instantiate( + cfg.agent, + obs_space_num=env.observation_space.shape, + actions_num=env.action_space.n if is_discrete else env.action_space.shape[0], + action_type='discrete' if is_discrete else 'continuous', + device_type=cfg.device_type, + logger=logger) + + prof = profile( + activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], + on_trace_ready=torch.profiler.tensorboard_trace_handler(logger.log_dir() + '/profiler'), + schedule=torch.profiler.schedule(wait=10, warmup=10, active=5, repeat=5), + with_stack=True) if cfg.debug.profiler else None for i in tqdm(range(int(cfg.training.pretrain)), desc='Pretraining'): if cfg.training.checkpoint_path is not None: break s, a, r, n, f, first = buff.sample(cfg.training.batch_size, - cluster_size=cfg.agent.get('batch_cluster_size', 1)) + cluster_size=cfg.agent.get( + 'batch_cluster_size', 1)) losses = agent.train(s, a, r, n, f, first) - for loss_name, loss in losses.items(): - if 'grad' in loss_name: - if cfg.logger.log_grads: - writer.add_histogram(f'pre_train/{loss_name}', loss, i) - else: - writer.add_scalar(f'pre_train/{loss_name}', loss.item(), i) + logger.log(losses, i, mode='pre_train') # TODO: remove constants log_every_n = 25 - st = int(cfg.training.pretrain) // log_every_n if i % log_every_n == 0: - val_logs(agent, cfg.validation, val_env, -st + i/log_every_n, writer) + st = int(cfg.training.pretrain) // log_every_n + val_logs(agent, cfg.validation, val_env, -st + i // log_every_n, logger) if cfg.training.checkpoint_path is not None: prev_global_step = global_step = agent.load_ckpt(cfg.training.checkpoint_path) @@ -125,32 +116,31 @@ def main(cfg: DictConfig): for s, a, r, n, f, _ in iter_rollout(env, agent): buff.add_sample(s, a, r, n, f) - if global_step % cfg.training.gradient_steps_per_step == 0: + if global_step % cfg.training.train_every == 0: # NOTE: unintuitive that batch_size is now number of total # samples, but not amount of sequences for recurrent model s, a, r, n, f, first = buff.sample(cfg.training.batch_size, - cluster_size=cfg.agent.get('batch_cluster_size', 1)) + cluster_size=cfg.agent.get( + 'batch_cluster_size', 1)) losses = agent.train(s, a, r, n, f, first) if cfg.debug.profiler: prof.step() if global_step % 100 == 0: - for loss_name, loss in losses.items(): - if 'grad' in loss_name: - if cfg.logger.log_grads: - writer.add_histogram(f'train/{loss_name}', loss, global_step) - else: - writer.add_scalar(f'train/{loss_name}', loss.item(), global_step) + logger.log(losses, global_step, mode='train') + global_step += cfg.env.repeat_action_num pbar.update(cfg.env.repeat_action_num) # FIXME: find more appealing solution ### Validation - if (global_step % cfg.training.val_logs_every) < (prev_global_step % cfg.training.val_logs_every): - val_logs(agent, cfg.validation, val_env, global_step, writer) + if (global_step % cfg.training.val_logs_every) < (prev_global_step % + cfg.training.val_logs_every): + val_logs(agent, cfg.validation, val_env, global_step, 
logger) ### Checkpoint - if (global_step % cfg.training.save_checkpoint_every) < (prev_global_step % cfg.training.save_checkpoint_every): + if (global_step % cfg.training.save_checkpoint_every) < ( + prev_global_step % cfg.training.save_checkpoint_every): agent.save_ckpt(global_step, losses) prev_global_step = global_step @@ -161,4 +151,3 @@ def main(cfg: DictConfig): if __name__ == "__main__": main() - diff --git a/rl_sandbox/utils/logger.py b/rl_sandbox/utils/logger.py new file mode 100644 index 0000000..c263049 --- /dev/null +++ b/rl_sandbox/utils/logger.py @@ -0,0 +1,61 @@ +from torch.utils.tensorboard.writer import SummaryWriter +import typing as t + + +class SummaryWriterMock(): + def __init__(self): + self.log_dir = None + + def add_scalar(*args, **kwargs): + pass + + def add_video(*args, **kwargs): + pass + + def add_image(*args, **kwargs): + pass + + def add_histogram(*args, **kwargs): + pass + + def add_figure(*args, **kwargs): + pass + + +class Logger: + def __init__(self, type: t.Optional[str], + message: t.Optional[str] = None, + log_grads: bool = True) -> None: + self.type = type + match type: + case "tensorboard": + self.writer = SummaryWriter(comment=message or "") + case None: + self.writer = SummaryWriterMock() + case _: + raise ValueError(f"Unknown logger type: {type}") + self.log_grads = log_grads + + + def log(self, losses: dict[str, t.Any], global_step: int, mode: str = 'train'): + for loss_name, loss in losses.items(): + if 'grad' in loss_name: + if self.log_grads: + self.writer.add_histogram(f'train/{loss_name}', loss, global_step) + else: + self.writer.add_scalar(f'train/{loss_name}', loss.item(), global_step) + + def add_scalar(self, name: str, value: t.Any, global_step: int): + self.writer.add_scalar(name, value, global_step) + + def add_image(self, name: str, image: t.Any, global_step: int): + self.writer.add_image(name, image, global_step, dataformats='HW') + + def add_video(self, name: str, video: t.Any, global_step: int): + self.writer.add_video(name, video, global_step, fps=20) + + def add_figure(self, name: str, figure: t.Any, global_step: int): + self.writer.add_figure(name, figure, global_step) + + def log_dir(self) -> str: + return self.writer.log_dir diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index 6a2716c..4572376 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -32,7 +32,10 @@ def __len__(self): # TODO: make buffer concurrent-friendly class ReplayBuffer: - def __init__(self, max_len=2e6, prioritize_ends: bool = False, min_ep_len: int = 1): + def __init__(self, max_len=2e6, + prioritize_ends: bool = False, + min_ep_len: int = 1, + device: str = 'cpu'): self.rollouts: deque[Rollout] = deque() self.rollouts_len: deque[int] = deque() self.curr_rollout = None @@ -40,6 +43,7 @@ def __init__(self, max_len=2e6, prioritize_ends: bool = False, min_ep_len: int = self.prioritize_ends = prioritize_ends self.max_len = max_len self.total_num = 0 + self.device = device def __len__(self): return self.total_num @@ -87,13 +91,12 @@ def sample( batch_size: int, cluster_size: int = 1 ) -> tuple[States, Actions, Rewards, States, TerminationFlags, IsFirstFlags]: - seq_num = batch_size // cluster_size # NOTE: constant creation of numpy arrays from self.rollout_len seems terrible for me s, a, r, n, t, is_first = [], [], [], [], [], [] do_add_curr = self.curr_rollout is not None and len(self.curr_rollout.states) > (cluster_size * (self.prioritize_ends + 1)) tot = self.total_num + 
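In ReplayBuffer.sample(), the episode to draw from is chosen with probability proportional to its stored length, so each stored transition ends up roughly equally likely regardless of which rollout it belongs to. A minimal standalone sketch of that idea (hypothetical helper name, plain numpy, ignoring the cluster_size and prioritize_ends handling of the real sample()):

import numpy as np

def sample_episode_indices(episode_lengths, num_samples, rng=None):
    # Picking an episode with probability proportional to its length makes
    # every stored transition (approximately) equally likely to be sampled.
    rng = rng or np.random.default_rng()
    lengths = np.asarray(episode_lengths, dtype=np.float64)
    return rng.choice(len(lengths), size=num_samples, p=lengths / lengths.sum())

# e.g. episodes of length 10, 100 and 1000 steps
print(sample_episode_indices([10, 100, 1000], num_samples=5))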
(len(self.curr_rollout.states) if do_add_curr else 0) r_indeces = np.random.choice(len(self.rollouts) + int(do_add_curr), - seq_num, + batch_size, p=np.array(self.rollouts_len + deque([len(self.curr_rollout.states)] if do_add_curr else [])) / tot) s_indeces = [] for r_idx in r_indeces: From 3b9906fab1eaf76b78bc9fa66461f8c01a9b99e6 Mon Sep 17 00:00:00 2001 From: Midren Date: Thu, 23 Mar 2023 15:35:28 +0000 Subject: [PATCH 050/106] Added reconstruction of DINO features --- rl_sandbox/agents/dreamer_v2.py | 20 +- rl_sandbox/config/env/dm_manipulator.yaml | 8 + rl_sandbox/vision/dino.py | 360 ++++++++++++++++++++++ 3 files changed, 385 insertions(+), 3 deletions(-) create mode 100644 rl_sandbox/config/env/dm_manipulator.yaml create mode 100644 rl_sandbox/vision/dino.py diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 4498e9d..a9fd502 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -11,6 +11,7 @@ from torch import nn from torch.nn import functional as F from jaxtyping import Float, Bool +from rl_sandbox.vision.dino import ViTFeat from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.fc_nn import fc_nn_generator @@ -426,8 +427,19 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, discrete_rssm, norm_layer=nn.Identity if layer_norm else nn.LayerNorm) self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm) + self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", 768, 'base', 'k', 8) + self.dino_vit.requires_grad_(False) + self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, + 64*768, + hidden_size=2048, + num_layers=3, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.image_predictor = Decoder(rssm_dim + latent_dim * latent_classes, norm_layer=nn.Identity if layer_norm else nn.GroupNorm) + self.reward_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, 1, hidden_size=400, @@ -517,12 +529,14 @@ def KL(dist1, dist2, free_nat = True): r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) - x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + d_pred = self.dino_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) prior_logits = prior.stoch_logits posterior_logits = posterior.stoch_logits - - losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() + ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + d_features = self.dino_vit(ToTensor(obs + 0.5))[0] + losses['loss_reconstruction'] = -d_pred.log_prob(d_features.flatten(1, 2)).float().mean() losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) diff --git a/rl_sandbox/config/env/dm_manipulator.yaml b/rl_sandbox/config/env/dm_manipulator.yaml new file mode 100644 index 0000000..6cbeea6 --- /dev/null +++ b/rl_sandbox/config/env/dm_manipulator.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: manipulator +task_name: bring_ball +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/vision/dino.py b/rl_sandbox/vision/dino.py new file mode 100644 index 0000000..791e6af --- /dev/null +++ b/rl_sandbox/vision/dino.py @@ -0,0 +1,360 
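Patch 050 replaces the pixel reconstruction target with frozen DINO ViT features: dino_predictor maps the posterior latent to a 64x768 feature map and is trained with -log_prob against the teacher's patch features. If the 'mse' DistLayer behaves like a unit-variance Normal head, that objective reduces, up to constants, to a squared error against the detached teacher features. A simplified sketch of that objective with toy shapes and plain MSE standing in for the distribution head:

import torch
import torch.nn.functional as F

def dino_feature_loss(pred_feats, teacher_feats):
    # pred_feats:    (batch, num_patches, feat_dim) from the latent predictor
    # teacher_feats: (batch, num_patches, feat_dim) from the frozen DINO encoder
    return F.mse_loss(pred_feats, teacher_feats.detach())

pred = torch.randn(8, 64, 768, requires_grad=True)
with torch.no_grad():
    teacher = torch.randn(8, 64, 768)  # stands in for frozen DINO outputs
dino_feature_loss(pred, teacher).backward()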
@@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Taken from YangtaoWANG95/TokenCut/unsupervised_saliency_detection/dino.py + +Copied from Dino repo. https://github.com/facebookresearch/dino +Mostly copy-paste from timm library. +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" +import math +from functools import partial +import warnings + +import torch +import torch.nn as nn + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, return_attention=False): + y, attn = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size // patch_size) * (img_size // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): + super().__init__() + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def interpolate_pos_encoding(self, x, w, h): + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_embed.patch_size + h0 = h // self.patch_embed.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def 
prepare_tokens(self, x): + B, nc, w, h = x.shape + x = self.patch_embed(x) # patch linear embedding + + # add the [CLS] token to the embed patch tokens + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # add positional encoding to each token + x = x + self.interpolate_pos_encoding(x, w, h) + + return self.pos_drop(x) + + def forward(self, x): + x = self.prepare_tokens(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + return x[:, 0] + + def get_last_selfattention(self, x): + x = self.prepare_tokens(x) + for i, blk in enumerate(self.blocks): + if i < len(self.blocks) - 1: + x = blk(x) + else: + # return attention of the last block + return blk(x, return_attention=True) + + def get_intermediate_layers(self, x, n=1): + x = self.prepare_tokens(x) + # we return the output tokens from the `n` last blocks + output = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if len(self.blocks) - i <= n: + output.append(self.norm(x)) + return output + + + +def vit_small(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_base(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + + + +class ViTFeat(nn.Module): + """ Vision Transformer """ + def __init__(self, pretrained_pth, feat_dim, vit_arch = 'base', vit_feat = 'k', patch_size=16): + super().__init__() + if vit_arch == 'base' : + self.model = vit_base(patch_size=patch_size, num_classes=0) + + else : + self.model = vit_small(patch_size=patch_size, num_classes=0) + + self.feat_dim = feat_dim + self.vit_feat = vit_feat + self.patch_size = patch_size + +# state_dict = torch.load(pretrained_pth, map_location="cpu") + state_dict = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com"+pretrained_pth) + self.model.load_state_dict(state_dict, strict=True) + print('Loading weight from {}'.format(pretrained_pth)) + + + def forward(self, img) : + feat_out = {} + def hook_fn_forward_qkv(module, input, output): + feat_out["qkv"] = output + + self.model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) + + + # Forward pass in the model + with torch.no_grad() : + h, w = img.shape[2], img.shape[3] + feat_h, feat_w = h // self.patch_size, w // self.patch_size + attentions = self.model.get_last_selfattention(img) + bs, nb_head, nb_token = attentions.shape[0], attentions.shape[1], attentions.shape[2] + qkv = ( + feat_out["qkv"] + .reshape(bs, nb_token, 3, nb_head, -1) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + + k = k.transpose(1, 2).reshape(bs, nb_token, -1) + q = q.transpose(1, 2).reshape(bs, nb_token, -1) + v = v.transpose(1, 2).reshape(bs, nb_token, -1) + + # Modality selection + if self.vit_feat == "k": + feats = k[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) + elif self.vit_feat == "q": + feats = q[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) + elif self.vit_feat == "v": + feats = v[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) + elif self.vit_feat == "kqv": + k = k[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) + q = q[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) + v = v[:, 
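ViTFeat extracts q/k/v from the last attention block by registering a forward hook on that block's qkv linear layer rather than modifying the model itself. A minimal sketch of the hook pattern on a toy module (not the DINO ViT):

import torch
from torch import nn

feat_out = {}

def hook_fn(module, inputs, output):
    feat_out["qkv"] = output  # activation captured during the forward pass

model = nn.Sequential(nn.Linear(16, 48), nn.GELU(), nn.Linear(48, 8))
handle = model[0].register_forward_hook(hook_fn)  # hook the qkv-like layer

with torch.no_grad():
    model(torch.randn(4, 16))
print(feat_out["qkv"].shape)  # torch.Size([4, 48])
handle.remove()  # drop the hook once the activation has been read

Since ViTFeat.forward registers the hook on every call, keeping the returned handle and removing it afterwards (or registering once in __init__) avoids accumulating hooks over repeated forward passes.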
1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) + feats = torch.cat([k, q, v], dim=1) + return feats, attentions + + +if __name__ == "__main__": + model = ViTFeat('/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', 64, 'base', 'k', patch_size=8) + img = torch.FloatTensor(4, 3, 64, 64) + # Forward pass in the model + feat = model(img) + print (feat[0].shape) From 0d3a6b08d634fadde97bac23682c23129aa491fd Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 4 Apr 2023 11:07:49 +0100 Subject: [PATCH 051/106] tuned parametrs for dino features, added convolutional vit decoder --- .vimspector.json | 7 + pyproject.toml | 4 +- rl_sandbox/agents/dreamer_v2.py | 143 +++++++++++--- rl_sandbox/config/agent/dreamer_v2.yaml | 10 +- .../config/agent/dreamer_v2_crafter.yaml | 8 +- rl_sandbox/config/config.yaml | 14 +- rl_sandbox/config/env/dm_acrobot.yaml | 8 + rl_sandbox/config/env/dm_cartpole.yaml | 3 +- rl_sandbox/config/training/crafter.yaml | 8 + rl_sandbox/config/training/dm.yaml | 8 + rl_sandbox/train.py | 12 +- rl_sandbox/vision/dino.py | 2 +- rl_sandbox/vision/my_slot_attention.py | 185 ++++++++++++++++++ rl_sandbox/vision/slot_attention.py | 69 +++++++ 14 files changed, 428 insertions(+), 53 deletions(-) create mode 100644 rl_sandbox/config/env/dm_acrobot.yaml create mode 100644 rl_sandbox/config/training/crafter.yaml create mode 100644 rl_sandbox/config/training/dm.yaml create mode 100644 rl_sandbox/vision/my_slot_attention.py create mode 100644 rl_sandbox/vision/slot_attention.py diff --git a/.vimspector.json b/.vimspector.json index 4dfda9e..3945f69 100644 --- a/.vimspector.json +++ b/.vimspector.json @@ -42,6 +42,13 @@ "program": "rl_sandbox/train.py", "args": [] } + }, + "Run dino": { + "extends": "python-base", + "configuration": { + "program": "rl_sandbox/vision/my_slot_attention.py", + "args": [] + } } } } diff --git a/pyproject.toml b/pyproject.toml index 9644b40..94c5203 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,8 +17,8 @@ nptyping = '*' gym = "0.25.0" # crafter requires old step api pygame = '*' moviepy = '*' -torchvision = '^0.13' -torch = '^1.12' +torchvision = '*' +torch = '*' tensorboard = '^2.0' dm-control = '^1.0.0' unpackable = '^0.0.4' diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index a9fd502..75bfe23 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -10,6 +10,7 @@ import torch.distributions as td from torch import nn from torch.nn import functional as F +import torchvision as tv from jaxtyping import Float, Bool from rl_sandbox.vision.dino import ViTFeat @@ -385,6 +386,35 @@ def forward(self, X): x = x.view(-1, 32 * self.channel_step, 1, 1) return td.Independent(td.Normal(self.net(x), 1.0), 3) +class ViTDecoder(nn.Module): + + def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3]): + super().__init__() + layers = [] + self.channel_step = 12 + # 2**(len(kernel_sizes)-1)*channel_step + self.convin = nn.Linear(input_size, 32 * self.channel_step) + + in_channels = 32 * self.channel_step #2**(len(kernel_sizes) - 1) * self.channel_step + for i, k in enumerate(kernel_sizes): + out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step + if i == len(kernel_sizes) - 1: + out_channels = 3 + layers.append(nn.ConvTranspose2d(in_channels, 384, kernel_size=k, stride=1, padding=1)) + else: + layers.append(norm_layer(1, in_channels)) + layers.append( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=2, padding=2, 
output_padding=1)) + layers.append(nn.ELU(inplace=True)) + in_channels = out_channels + self.net = nn.Sequential(*layers) + + def forward(self, X): + x = self.convin(X) + x = x.view(-1, 32 * self.channel_step, 1, 1) + return td.Independent(td.Normal(self.net(x), 1.0), 3) + + class Normalizer(nn.Module): def __init__(self, momentum=0.99, scale=1.0, eps=1e-8): @@ -407,7 +437,7 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, - predict_discount, layer_norm: bool): + predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool): super().__init__() self.kl_free_nats = kl_free_nats self.kl_beta = kl_loss_scale @@ -419,6 +449,8 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, # kl loss balancing (prior/posterior) self.alpha = kl_loss_balancing self.predict_discount = predict_discount + self.encode_vit = encode_vit + self.decode_vit = decode_vit self.recurrent_model = RSSM(latent_dim, rssm_dim, @@ -426,17 +458,36 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, latent_classes, discrete_rssm, norm_layer=nn.Identity if layer_norm else nn.LayerNorm) - self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm) - self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", 768, 'base', 'k', 8) - self.dino_vit.requires_grad_(False) - self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, - 64*768, - hidden_size=2048, - num_layers=3, - intermediate_activation=nn.ELU, - layer_norm=layer_norm, - final_activation=DistLayer('mse')) - + if encode_vit or decode_vit: + # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", 768, 'base', 'k', 8) + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", 384, 'small', 'k', 8) + self.dino_vit.requires_grad_(False) + + if encode_vit: + self.encoder = nn.Sequential( + self.dino_vit, + nn.Flatten(), + # fc_nn_generator(64*self.dino_vit.feat_dim, + # 64*384, + # hidden_size=400, + # num_layers=5, + # intermediate_activation=nn.ELU, + # layer_norm=layer_norm) + ) + else: + self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm) + + + if decode_vit: + self.dino_predictor = ViTDecoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.Identity if layer_norm else nn.GroupNorm) + # self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, + # 64*self.dino_vit.feat_dim, + # hidden_size=2048, + # num_layers=5, + # intermediate_activation=nn.ELU, + # layer_norm=layer_norm, + # final_activation=DistLayer('mse')) self.image_predictor = Decoder(rssm_dim + latent_dim * latent_classes, norm_layer=nn.Identity if layer_norm else nn.GroupNorm) @@ -529,14 +580,24 @@ def KL(dist1, dist2, free_nat = True): r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) - d_pred = self.dino_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + if not self.decode_vit: + x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + losses['loss_reconstruction_img'] = 0 + losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() + else: + x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1).detach()) + d_pred = self.dino_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + inp = obs + if 
not self.encode_vit: + ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + inp = ToTensor(obs + 0.5) + d_features = self.dino_vit(inp) + losses['loss_reconstruction_img'] = -x_r.log_prob(obs).float().mean() + losses['loss_reconstruction'] = -d_pred.log_prob(d_features.reshape(b, 384, 8, 8)).float().mean()/2 prior_logits = prior.stoch_logits posterior_logits = posterior.stoch_logits - ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) - d_features = self.dino_vit(ToTensor(obs + 0.5))[0] - losses['loss_reconstruction'] = -d_pred.log_prob(d_features.flatten(1, 2)).float().mean() losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) @@ -635,6 +696,8 @@ def __init__( critic_lr: float, discrete_rssm: bool, layer_norm: bool, + encode_vit: bool, + decode_vit: bool, device_type: str = 'cpu', logger = None): @@ -653,7 +716,8 @@ def __init__( rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_loss_free_nats, discrete_rssm, - world_model_predict_discount, layer_norm).to(device_type) + world_model_predict_discount, layer_norm, + encode_vit, decode_vit).to(device_type) self.actor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, actions_num if self.is_discrete else actions_num * 2, @@ -670,6 +734,10 @@ def __init__( layer_norm=layer_norm).to(device_type) self.scaler = torch.cuda.amp.GradScaler() + self.image_predictor_optimizer = torch.optim.AdamW(self.world_model.image_predictor.parameters(), + lr=world_model_lr, + eps=1e-5, + weight_decay=1e-6) self.world_model_optimizer = torch.optim.AdamW(self.world_model.parameters(), lr=world_model_lr, eps=1e-5, @@ -721,13 +789,17 @@ def reset(self): self._action_probs = torch.zeros((self.actions_num), device=self.device) self._stored_steps = 0 - @staticmethod - def preprocess_obs(obs: torch.Tensor): + def preprocess_obs(self, obs: torch.Tensor): # FIXME: move to dataloader in replay buffer order = list(range(len(obs.shape))) # Swap channel from last to 3 from last order = order[:-3] + [order[-1]] + order[-3:-1] - return ((obs.type(torch.float32) / 255.0) - 0.5).permute(order) + if self.world_model.encode_vit: + ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + return ToTensor(obs.type(torch.float32).permute(order)) + else: + return ((obs.type(torch.float32) / 255.0) - 0.5).permute(order) # return obs.type(torch.float32).permute(order) def get_action(self, obs: Observation) -> Action: @@ -760,13 +832,20 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ rews = [] state = None + means = np.array([0.485, 0.456, 0.406]) + stds = np.array([0.229, 0.224, 0.225]) + UnNormalize = tv.transforms.Normalize(list(-means/stds), + list(1/stds)) for idx, (o, a) in enumerate(list(zip(obs, actions))): if idx > update_num: break state = self.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), state) video_r = self.world_model.image_predictor(state.combined).mode.cpu().detach().numpy() rews.append(self.world_model.reward_predictor(state.combined).mode.item()) - video_r = (video_r + 0.5) + if self.world_model.encode_vit: + video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + else: + video_r = (video_r + 0.5) video.append(video_r) rews = torch.Tensor(rews).to(obs.device) @@ -775,7 +854,10 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ states, _, rews_2, _ 
= self.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.imagination_horizon - 1 - update_num) rews = torch.cat([rews, rews_2[1:].squeeze()]) video_r = self.world_model.image_predictor(states.combined[1:]).mode.cpu().detach().numpy() - video_r = (video_r + 0.5) + if self.world_model.encode_vit: + video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + else: + video_r = (video_r + 0.5) video.append(video_r) return np.concatenate(video), rews @@ -843,20 +925,27 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device world_model_loss = torch.Tensor(0).to(self.device) + image_predictor_loss = losses['loss_reconstruction_img'] world_model_loss = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + losses['loss_kl_reg'] + losses['loss_discount_pred']) # for l in losses.values(): # world_model_loss += l + if self.world_model.decode_vit: + self.image_predictor_optimizer.zero_grad(set_to_none=True) + self.scaler.scale(image_predictor_loss).backward() + self.scaler.unscale_(self.image_predictor_optimizer) + nn.utils.clip_grad_norm_(self.world_model.image_predictor.parameters(), 100) + self.scaler.step(self.image_predictor_optimizer) self.world_model_optimizer.zero_grad(set_to_none=True) self.scaler.scale(world_model_loss).backward() # FIXME: clip gradient should be parametrized self.scaler.unscale_(self.world_model_optimizer) - for tag, value in self.world_model.named_parameters(): - wm_metrics[f"grad/{tag.replace('.', '/')}"] = value.detach() + # for tag, value in self.world_model.named_parameters(): + # wm_metrics[f"grad/{tag.replace('.', '/')}"] = value.detach() nn.utils.clip_grad_norm_(self.world_model.parameters(), 100) self.scaler.step(self.world_model_optimizer) @@ -927,8 +1016,8 @@ def calculate_entropy(dist): nn.utils.clip_grad_norm_(self.actor.parameters(), 100) nn.utils.clip_grad_norm_(self.critic.parameters(), 100) - for tag, value in self.actor.named_parameters(): - wm_metrics[f"grad/{tag.replace('.', '/')}"] = value.detach() + # for tag, value in self.actor.named_parameters(): + # wm_metrics[f"grad/{tag.replace('.', '/')}"] = value.detach() self.scaler.step(self.actor_optimizer) self.scaler.step(self.critic_optimizer) diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index 28cebfb..82e2c78 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -1,5 +1,5 @@ _target_: rl_sandbox.agents.DreamerV2 -layer_norm: true +layer_norm: false # World model parameters batch_cluster_size: 50 latent_dim: 32 @@ -7,7 +7,7 @@ latent_classes: 32 rssm_dim: 200 kl_loss_scale: 1.0 kl_loss_balancing: 0.8 -kl_loss_free_nats: 1.0 +kl_loss_free_nats: 0.25 world_model_lr: 3e-4 world_model_predict_discount: false @@ -15,13 +15,13 @@ world_model_predict_discount: false discount_factor: 0.999 imagination_horizon: 15 -actor_lr: 8e-5 +actor_lr: 3e-4 # mixing of reinforce and maximizing value func # for dm_control it is zero in Dreamer (Atari 1) actor_reinforce_fraction: null actor_entropy_scale: 1e-4 -critic_lr: 8e-5 +critic_lr: 3e-4 # Lambda parameter for trainin deeper multi-step prediction critic_value_target_lambda: 0.95 critic_update_interval: 100 @@ -29,3 +29,5 @@ critic_update_interval: 100 critic_soft_update_fraction: 1 discrete_rssm: false +decode_vit: true +encode_vit: false diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml 
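The UnNormalize transform used when visualizing ViT-encoded rollouts is simply a second Normalize with mean -mean/std and std 1/std, which algebraically inverts the ImageNet normalization. A quick self-contained check of that identity:

import torch
import torchvision as tv

mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

normalize = tv.transforms.Normalize(mean.tolist(), std.tolist())
unnormalize = tv.transforms.Normalize((-mean / std).tolist(), (1 / std).tolist())

img = torch.rand(3, 64, 64)
assert torch.allclose(unnormalize(normalize(img)), img, atol=1e-5)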
b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index 5dc802a..8b93550 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -8,19 +8,19 @@ rssm_dim: 1024 kl_loss_scale: 1.0 kl_loss_balancing: 0.8 kl_loss_free_nats: 0.0 -world_model_lr: 1e-4 +world_model_lr: 3e-4 world_model_predict_discount: true # ActorCritic parameters discount_factor: 0.999 imagination_horizon: 15 -actor_lr: 1e-4 +actor_lr: 3e-4 # automatically chooses depending on discrete/continuous env actor_reinforce_fraction: null actor_entropy_scale: 3e-3 -critic_lr: 1e-4 +critic_lr: 3e-4 # Lambda parameter for trainin deeper multi-step prediction critic_value_target_lambda: 0.95 critic_update_interval: 100 @@ -28,3 +28,5 @@ critic_update_interval: 100 critic_soft_update_fraction: 1 discrete_rssm: false +decode_vit: true +encode_vit: false diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index b1121f6..63f6a94 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,6 +1,7 @@ defaults: - agent: dreamer_v2 - env: dm_cartpole + - training: dm - _self_ seed: 42 @@ -8,19 +9,16 @@ device_type: cuda logger: type: tensorboard - message: Cartpole with discrete + message: Cartpole with convolutional decoder + #message: test_last log_grads: false training: checkpoint_path: null - steps: 5e5 - prefill: 10000 + steps: 1e6 batch_size: 50 - pretrain: 100 - prioritize_ends: false - train_every: 5 - save_checkpoint_every: 2e5 - val_logs_every: 2.5e4 + val_logs_every: 2e4 + #val_logs_every: 2.5e3 validation: diff --git a/rl_sandbox/config/env/dm_acrobot.yaml b/rl_sandbox/config/env/dm_acrobot.yaml new file mode 100644 index 0000000..313a1f6 --- /dev/null +++ b/rl_sandbox/config/env/dm_acrobot.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: acrobot +task_name: swingup +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_cartpole.yaml b/rl_sandbox/config/env/dm_cartpole.yaml index 4bd3888..5cb6345 100644 --- a/rl_sandbox/config/env/dm_cartpole.yaml +++ b/rl_sandbox/config/env/dm_cartpole.yaml @@ -5,5 +5,4 @@ run_on_pixels: true obs_res: [64, 64] camera_id: 0 repeat_action_num: 2 -transforms: - - _target_: rl_sandbox.utils.env.ActionNormalizer +transforms: [] diff --git a/rl_sandbox/config/training/crafter.yaml b/rl_sandbox/config/training/crafter.yaml new file mode 100644 index 0000000..85b4242 --- /dev/null +++ b/rl_sandbox/config/training/crafter.yaml @@ -0,0 +1,8 @@ +steps: 1e6 +prefill: 10000 +batch_size: 16 +pretrain: 1 +prioritize_ends: true +train_every: 5 +save_checkpoint_every: 2e5 +val_logs_every: 2e4 diff --git a/rl_sandbox/config/training/dm.yaml b/rl_sandbox/config/training/dm.yaml new file mode 100644 index 0000000..fc2bfc1 --- /dev/null +++ b/rl_sandbox/config/training/dm.yaml @@ -0,0 +1,8 @@ +steps: 1e6 +prefill: 1000 +batch_size: 16 +pretrain: 100 +prioritize_ends: false +train_every: 5 +save_checkpoint_every: 2e5 +val_logs_every: 2e4 diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 4c76780..150c83b 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -32,7 +32,7 @@ def val_logs(agent, val_cfg: DictConfig, env: Env, global_step: int, logger: Log for rollout in rollouts: video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) - logger.add_video('val/visualization', video, global_step, fps=20) + logger.add_video('val/visualization', video, global_step) # FIXME: Very bad from 
architecture point with torch.no_grad(): agent.viz_log(rollout, logger, global_step) @@ -54,7 +54,7 @@ def main(cfg: DictConfig): # TODO: Implement smarter techniques for exploration # (Plan2Explore, etc) - logger = Logger(*cfg.logger) + logger = Logger(**cfg.logger) env: Env = hydra.utils.instantiate(cfg.env) val_env: Env = hydra.utils.instantiate(cfg.env) @@ -99,10 +99,10 @@ def main(cfg: DictConfig): logger.log(losses, i, mode='pre_train') # TODO: remove constants - log_every_n = 25 - if i % log_every_n == 0: - st = int(cfg.training.pretrain) // log_every_n - val_logs(agent, cfg.validation, val_env, -st + i // log_every_n, logger) + # log_every_n = 25 + # if i % log_every_n == 0: + # st = int(cfg.training.pretrain) // log_every_n + val_logs(agent, cfg.validation, val_env, -1, logger) if cfg.training.checkpoint_path is not None: prev_global_step = global_step = agent.load_ckpt(cfg.training.checkpoint_path) diff --git a/rl_sandbox/vision/dino.py b/rl_sandbox/vision/dino.py index 791e6af..bcd24e3 100644 --- a/rl_sandbox/vision/dino.py +++ b/rl_sandbox/vision/dino.py @@ -349,7 +349,7 @@ def hook_fn_forward_qkv(module, input, output): q = q[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) v = v[:, 1:].transpose(1, 2).reshape(bs, self.feat_dim, feat_h * feat_w) feats = torch.cat([k, q, v], dim=1) - return feats, attentions + return feats if __name__ == "__main__": diff --git a/rl_sandbox/vision/my_slot_attention.py b/rl_sandbox/vision/my_slot_attention.py new file mode 100644 index 0000000..4a4ffd5 --- /dev/null +++ b/rl_sandbox/vision/my_slot_attention.py @@ -0,0 +1,185 @@ +import torch +import typing as t +from torch import nn +import torch.nn.functional as F +from jaxtyping import Float +import torchvision as tv +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +import numpy as np + +from rl_sandbox.vision.vae import ResBlock + +class SlotAttention(nn.Module): + def __init__(self, num_slots: int, seq_num: int, n_dim: int, n_iter: int): + super().__init__() + + self.seq_num = seq_num + self.n_slots = num_slots + self.n_iter = n_iter + self.n_dim = n_dim + self.scale = self.n_dim**(-1/2) + self.epsilon = 1e-8 + + # self.norm = LayerNorm() + self.slots_mu = nn.Parameter(torch.randn(1, 1, self.n_dim)) + self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.n_dim)) + nn.init.xavier_uniform_(self.slots_logsigma) + + self.slots_proj = nn.Linear(n_dim, n_dim) + self.slots_proj_2 = nn.Linear(n_dim, n_dim) + self.slots_norm = nn.LayerNorm(self.n_dim) + self.slots_norm_2 = nn.LayerNorm(self.n_dim) + self.slots_reccur = nn.GRUCell(input_size=self.n_dim, hidden_size=self.n_dim) + + self.inputs_proj = nn.Linear(n_dim, n_dim*2, ) + self.inputs_norm = nn.LayerNorm(self.n_dim) + + def forward(self, X: Float[torch.Tensor, 'batch seq h w']) -> Float[torch.Tensor, 'batch num_slots n_dim']: + batch, seq, _, _ = X.shape + slots = self.slots_logsigma.exp() + self.slots_mu*torch.randn(batch, self.n_slots, self.n_dim, device=X.device) + + for _ in range(self.n_iter): + slots_prev = slots + k, v = self.inputs_proj(self.inputs_norm(X).permute(0, 2, 3, 1).reshape(batch, -1, self.n_dim)).chunk(2, dim=-1) + slots = self.slots_norm(slots) + q = self.slots_proj(slots) + + attn = F.softmax(self.scale*torch.einsum('bik,bjk->bij', q, k), dim=1) + self.epsilon + + attn = attn / attn.sum(dim=-1, keepdim=True) + + updates = torch.einsum('bij,bjk->bik', attn, v) / self.n_slots + slots = self.slots_reccur(updates.reshape(-1, self.n_dim), slots_prev.reshape(-1, 
self.n_dim)).reshape(batch, self.n_slots, self.n_dim) + slots = slots + self.slots_proj_2(self.slots_norm_2(slots)) + return slots + +def build_grid(resolution): + ranges = [np.linspace(0., 1., num=res) for res in resolution] + grid = np.meshgrid(*ranges, sparse=False, indexing="ij") + grid = np.stack(grid, axis=-1) + grid = np.reshape(grid, [resolution[0], resolution[1], -1]) + grid = np.expand_dims(grid, axis=0) + grid = grid.astype(np.float32) + return np.concatenate([grid, 1.0 - grid], axis=-1) + + +class PositionalEmbedding(nn.Module): + def __init__(self, n_dim: int): + super().__init__() + self.n_dim = n_dim + self.proj = nn.Linear(4, n_dim) + self.grid = torch.from_numpy(build_grid((64, 64))) + + def forward(self, X) -> torch.Tensor: + return X + self.proj(self.grid.to(X.device)) + # xs = torch.arange(self.n_dim) + # pos = torch.zeros_like(xs, dtype=torch.float) + # for i in range(self.n_dim): + # pos += torch.sin(xs / (1e4**(2*i/self.n_dim))) + # return X + pos + +class SlottedAutoEncoder(nn.Module): + def __init__(self, num_slots: int, n_iter: int): + super().__init__() + in_channels = 3 + latent_dim = 16 + self.encoder = nn.Sequential( + nn.Conv2d(in_channels, 64, kernel_size=5, padding='same'), + nn.LeakyReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, padding='same'), + nn.LeakyReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, padding='same'), + nn.LeakyReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, padding='same'), + nn.LeakyReLU(inplace=True), + ) + seq_num = latent_dim + n_dim = 64 + self.positional_augmenter = PositionalEmbedding(n_dim) + self.slot_attention = SlotAttention(num_slots, seq_num, n_dim, n_iter) + self.decoder = nn.Sequential( + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.LeakyReLU(inplace=True), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.LeakyReLU(inplace=True), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.LeakyReLU(inplace=True), + nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1), + ) + + def forward(self, X: Float[torch.Tensor, 'batch 3 h w']) -> t.Tuple[Float[torch.Tensor, 'batch 3 h w'], Float[torch.Tensor, 'batch num_slots 4 h w']]: + features = self.encoder(X) # -> batch D 64 + features_with_pos_enc = self.positional_augmenter(features) # -> batch D 64 + + slots = self.slot_attention(features_with_pos_enc) # -> batch num_slots 64 + slots = slots.flatten(0, 1).reshape(-1, 1, 1, 64) + slots = slots.repeat((1, 8, 8, 1)) + + decoded_imgs, masks = self.decoder(slots.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).reshape(X.shape[0], -1, *X.shape[2:], 4).split([3, 1], dim=-1) + + decoded_imgs = decoded_imgs * F.softmax(masks, dim=1) + rec_img = torch.sum(decoded_imgs, dim=1) + return rec_img.permute(0, 3, 1, 2), decoded_imgs.permute(0, 1, 4, 2, 3) + +if __name__ == '__main__': + device = 'cuda' + ToTensor = tv.transforms.Compose([tv.transforms.ToTensor(), + tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ] + ) + train_data = tv.datasets.ImageFolder('~/rl_old/rl_sandbox/crafter_data_2/', transform=ToTensor) + train_data_loader = torch.utils.data.DataLoader(train_data, + batch_size=64, + shuffle=True, + num_workers=8) + import socket + from datetime import datetime + current_time = datetime.now().strftime("%b%d_%H-%M-%S") + logger = SummaryWriter(log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}", comment="Added lr scheduler") + + number_of_slots = 7 + 
slots_iter_num = 3 + + total_steps = 5e5 + warmup_steps = 1e4 + decay_rate = 0.5 + decay_steps = 1e5 + + model = SlottedAutoEncoder(number_of_slots, slots_iter_num).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=4e-4) + lr_warmup_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1/warmup_steps, total_iters=int(warmup_steps)) + lr_decay_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: decay_rate**(epoch/decay_steps)) + # lr_decay_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate**(1/decay_steps)) + lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler([lr_warmup_scheduler, lr_decay_scheduler]) + + global_step = 0 + epoch = 0 + pbar = tqdm(total=total_steps, desc='Training') + while global_step < total_steps: + epoch += 1 + logger.add_scalar('epoch', epoch, epoch) + + for sample_num, (img, target) in enumerate(train_data_loader): + recovered_img, _ = model(img.to(device)) + + reg_loss = F.mse_loss(img.to(device), recovered_img) + + optimizer.zero_grad() + reg_loss.backward() + optimizer.step() + lr_scheduler.step() + + logger.add_scalar('train/loss', reg_loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) + pbar.update(1) + + for i in range(3): + img, target = next(iter(train_data_loader)) + recovered_img, imgs_per_slot = model(img.to(device)) + unnormalize = tv.transforms.Compose([ + tv.transforms.Normalize((0, 0, 0), (1/0.229, 1/0.224, 1/0.225)), + tv.transforms.Normalize((-0.485, -0.456, -0.406), (1., 1., 1.)) + ]) + logger.add_image(f'val/example_image', unnormalize(img.cpu().detach()[0]), epoch*3 + i) + logger.add_image(f'val/example_image_rec', unnormalize(recovered_img.cpu().detach()[0]), epoch*3 + i) + for i in range(6): + logger.add_image(f'val/example_image_rec_{i}', unnormalize(imgs_per_slot.cpu().detach()[0][i]), epoch*3 + i) diff --git a/rl_sandbox/vision/slot_attention.py b/rl_sandbox/vision/slot_attention.py new file mode 100644 index 0000000..4caefe9 --- /dev/null +++ b/rl_sandbox/vision/slot_attention.py @@ -0,0 +1,69 @@ +import torch +from torch import nn +from torch.nn import init + +class SlotAttention(nn.Module): + def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128): + super().__init__() + self.num_slots = num_slots + self.iters = iters + self.eps = eps + self.scale = dim ** -0.5 + + self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) + + self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim)) + init.xavier_uniform_(self.slots_logsigma) + + self.to_q = nn.Linear(dim, dim) + self.to_k = nn.Linear(dim, dim) + self.to_v = nn.Linear(dim, dim) + + self.gru = nn.GRUCell(dim, dim) + + hidden_dim = max(dim, hidden_dim) + + self.mlp = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.ReLU(inplace = True), + nn.Linear(hidden_dim, dim) + ) + + self.norm_input = nn.LayerNorm(dim) + self.norm_slots = nn.LayerNorm(dim) + self.norm_pre_ff = nn.LayerNorm(dim) + + def forward(self, inputs, num_slots = None): + b, n, d, device, dtype = *inputs.shape, inputs.device, inputs.dtype + n_s = num_slots if num_slots is not None else self.num_slots + + mu = self.slots_mu.expand(b, n_s, -1) + sigma = self.slots_logsigma.exp().expand(b, n_s, -1) + + slots = mu + sigma * torch.randn(mu.shape, device = device, dtype = dtype) + + inputs = self.norm_input(inputs) + k, v = self.to_k(inputs), self.to_v(inputs) + + for _ in range(self.iters): + slots_prev = slots + + slots = self.norm_slots(slots) + q = self.to_q(slots) + + dots = torch.einsum('bid,bjd->bij', q, k) * 
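The attention step in these SlotAttention modules is what distinguishes them from ordinary cross-attention: the softmax runs over the slot axis, so slots compete for input tokens, and the subsequent normalization over inputs turns the update into a weighted mean. A minimal sketch of one such step with toy tensors (projections, GRU and MLP refinement omitted):

import torch

def slot_attention_step(slots, inputs, eps=1e-8):
    # slots: (B, S, D), inputs: (B, N, D)
    d = slots.shape[-1]
    attn = torch.einsum('bsd,bnd->bsn', slots, inputs) * d ** -0.5
    attn = attn.softmax(dim=1) + eps               # softmax over slots, not inputs
    attn = attn / attn.sum(dim=-1, keepdim=True)   # weighted mean over inputs
    return torch.einsum('bsn,bnd->bsd', attn, inputs)

updates = slot_attention_step(torch.randn(2, 4, 64), torch.randn(2, 16, 64))
print(updates.shape)  # torch.Size([2, 4, 64])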
self.scale + attn = dots.softmax(dim=1) + self.eps + + attn = attn / attn.sum(dim=-1, keepdim=True) + + updates = torch.einsum('bjd,bij->bid', v, attn) + + slots = self.gru( + updates.reshape(-1, d), + slots_prev.reshape(-1, d) + ) + + slots = slots.reshape(b, -1, d) + slots = slots + self.mlp(self.norm_pre_ff(slots)) + + return slots From 2d2b658b1c40ee7b84ddb958037659f33e52d906 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Tue, 4 Apr 2023 17:25:22 +0000 Subject: [PATCH 052/106] Added upsample before ViT --- rl_sandbox/agents/dreamer_v2.py | 51 +++++++++++++++---------- rl_sandbox/config/agent/dreamer_v2.yaml | 1 + rl_sandbox/config/config.yaml | 10 +++-- rl_sandbox/train.py | 2 + rl_sandbox/vision/dino.py | 12 +++--- 5 files changed, 46 insertions(+), 30 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 75bfe23..7e75a52 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -388,7 +388,7 @@ def forward(self, X): class ViTDecoder(nn.Module): - def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3]): + def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3, 5, 3]): super().__init__() layers = [] self.channel_step = 12 @@ -435,9 +435,9 @@ def update(self, x): class WorldModel(nn.Module): - def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, + def __init__(self, img_size, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, - predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool): + predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float): super().__init__() self.kl_free_nats = kl_free_nats self.kl_beta = kl_loss_scale @@ -451,6 +451,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.predict_discount = predict_discount self.encode_vit = encode_vit self.decode_vit = decode_vit + self.vit_l2_ratio = vit_l2_ratio self.recurrent_model = RSSM(latent_dim, rssm_dim, @@ -459,8 +460,10 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, discrete_rssm, norm_layer=nn.Identity if layer_norm else nn.LayerNorm) if encode_vit or decode_vit: - # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", 768, 'base', 'k', 8) - self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", 384, 'small', 'k', 8) + # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) + self.vit_feat_dim = self.dino_vit.feat_dim + self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches self.dino_vit.requires_grad_(False) if encode_vit: @@ -580,21 +583,32 @@ def KL(dist1, dist2, free_nat = True): r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) + + losses['loss_reconstruction_img'] = 0 + if not self.decode_vit: x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) - losses['loss_reconstruction_img'] = 0 losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() else: - x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1).detach()) + if 
self.vit_l2_ratio != 1.0: + x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + img_rec = -x_r.log_prob(obs).float().mean() + else: + img_rec = 0 + x_r_detached = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1).detach()) + losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() d_pred = self.dino_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) inp = obs if not self.encode_vit: - ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(224, antialias=True)]) + # ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + # (0.229, 0.224, 0.225)) inp = ToTensor(obs + 0.5) d_features = self.dino_vit(inp) - losses['loss_reconstruction_img'] = -x_r.log_prob(obs).float().mean() - losses['loss_reconstruction'] = -d_pred.log_prob(d_features.reshape(b, 384, 8, 8)).float().mean()/2 + losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 28, 28)).float().mean()/2 + + (1-self.vit_l2_ratio) * img_rec) prior_logits = prior.stoch_logits posterior_logits = posterior.stoch_logits @@ -673,7 +687,7 @@ class DreamerV2(RlAgent): def __init__( self, - obs_space_num: int, # NOTE: encoder/decoder will work only with 64x64 currently + obs_space_num: list[int], # NOTE: encoder/decoder will work only with 64x64 currently actions_num: int, action_type: str, batch_cluster_size: int, @@ -698,6 +712,7 @@ def __init__( layer_norm: bool, encode_vit: bool, decode_vit: bool, + vit_l2_ratio: float, device_type: str = 'cpu', logger = None): @@ -712,12 +727,12 @@ def __init__( if self.rho is None: self.rho = self.is_discrete - self.world_model = WorldModel(batch_cluster_size, latent_dim, latent_classes, + self.world_model = WorldModel(obs_space_num[0], batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_loss_free_nats, discrete_rssm, world_model_predict_discount, layer_norm, - encode_vit, decode_vit).to(device_type) + encode_vit, decode_vit, vit_l2_ratio).to(device_type) self.actor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, actions_num if self.is_discrete else actions_num * 2, @@ -932,7 +947,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation losses['loss_discount_pred']) # for l in losses.values(): # world_model_loss += l - if self.world_model.decode_vit: + if self.world_model.decode_vit and self.world_model.vit_l2_ratio != 1.0: self.image_predictor_optimizer.zero_grad(set_to_none=True) self.scaler.scale(image_predictor_loss).backward() self.scaler.unscale_(self.image_predictor_optimizer) @@ -944,14 +959,11 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # FIXME: clip gradient should be parametrized self.scaler.unscale_(self.world_model_optimizer) - # for tag, value in self.world_model.named_parameters(): - # wm_metrics[f"grad/{tag.replace('.', '/')}"] = value.detach() nn.utils.clip_grad_norm_(self.world_model.parameters(), 100) self.scaler.step(self.world_model_optimizer) metrics = defaultdict( lambda: torch.zeros(1).to(self.device)) - metrics |= wm_metrics with torch.cuda.amp.autocast(enabled=True): losses_ac = defaultdict( @@ -1016,9 +1028,6 @@ def calculate_entropy(dist): nn.utils.clip_grad_norm_(self.actor.parameters(), 100) nn.utils.clip_grad_norm_(self.critic.parameters(), 100) - # for 
tag, value in self.actor.named_parameters(): - # wm_metrics[f"grad/{tag.replace('.', '/')}"] = value.detach() - self.scaler.step(self.actor_optimizer) self.scaler.step(self.critic_optimizer) diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index 82e2c78..603c257 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -30,4 +30,5 @@ critic_soft_update_fraction: 1 discrete_rssm: false decode_vit: true +vit_l2_ratio: 1.0 encode_vit: false diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 63f6a94..7a073f7 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,6 +1,6 @@ defaults: - agent: dreamer_v2 - - env: dm_cartpole + - env: dm_quadruped - training: dm - _self_ @@ -9,21 +9,23 @@ device_type: cuda logger: type: tensorboard - message: Cartpole with convolutional decoder + message: Quadruped with 224px input inside ViT #message: test_last log_grads: false training: checkpoint_path: null steps: 1e6 - batch_size: 50 + batch_size: 16 val_logs_every: 2e4 - #val_logs_every: 2.5e3 validation: rollout_num: 5 visualize: true +#agents: +# batch_cluster_size: 10 + debug: profiler: false diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 150c83b..a6fb2d5 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -1,4 +1,6 @@ import random +import os +os.environ['MUJOCO_GL'] = 'egl' import crafter import hydra diff --git a/rl_sandbox/vision/dino.py b/rl_sandbox/vision/dino.py index bcd24e3..87f36c1 100644 --- a/rl_sandbox/vision/dino.py +++ b/rl_sandbox/vision/dino.py @@ -276,15 +276,17 @@ def get_intermediate_layers(self, x, n=1): -def vit_small(patch_size=16, **kwargs): +def vit_small(patch_size=16, img_size=[224], **kwargs): model = VisionTransformer( + img_size=img_size, patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) return model -def vit_base(patch_size=16, **kwargs): +def vit_base(patch_size=16, img_size=[224], **kwargs): model = VisionTransformer( + img_size=img_size, patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) return model @@ -294,13 +296,13 @@ def vit_base(patch_size=16, **kwargs): class ViTFeat(nn.Module): """ Vision Transformer """ - def __init__(self, pretrained_pth, feat_dim, vit_arch = 'base', vit_feat = 'k', patch_size=16): + def __init__(self, pretrained_pth, feat_dim, vit_arch = 'base', vit_feat = 'k', patch_size=16, img_size=[224]): super().__init__() if vit_arch == 'base' : - self.model = vit_base(patch_size=patch_size, num_classes=0) + self.model = vit_base(patch_size=patch_size, num_classes=0, img_size=img_size) else : - self.model = vit_small(patch_size=patch_size, num_classes=0) + self.model = vit_small(patch_size=patch_size, num_classes=0, img_size=img_size) self.feat_dim = feat_dim self.vit_feat = vit_feat From 0a5ef3a955d5aa7b152f7a20f934edf87c4ada68 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Wed, 5 Apr 2023 10:44:47 +0000 Subject: [PATCH 053/106] Added 16 patch size with 224 px variant --- rl_sandbox/agents/dreamer_v2.py | 9 ++++++--- rl_sandbox/config/agent/dreamer_v2_crafter.yaml | 7 ++++--- rl_sandbox/config/config.yaml | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 7e75a52..cd4dd18 100644 --- 
a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -388,7 +388,9 @@ def forward(self, X): class ViTDecoder(nn.Module): - def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3, 5, 3]): + # def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3, 5, 3]): + # def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 5, 3]): + def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3, 3]): super().__init__() layers = [] self.channel_step = 12 @@ -461,7 +463,8 @@ def __init__(self, img_size, batch_cluster_size, latent_dim, latent_classes, rss norm_layer=nn.Identity if layer_norm else nn.LayerNorm) if encode_vit or decode_vit: # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) - self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) + # self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=16) self.vit_feat_dim = self.dino_vit.feat_dim self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches self.dino_vit.requires_grad_(False) @@ -607,7 +610,7 @@ def KL(dist1, dist2, free_nat = True): # (0.229, 0.224, 0.225)) inp = ToTensor(obs + 0.5) d_features = self.dino_vit(inp) - losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 28, 28)).float().mean()/2 + + losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 14, 14)).float().mean()/2 + (1-self.vit_l2_ratio) * img_rec) prior_logits = prior.stoch_logits diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index 8b93550..bd181db 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -8,19 +8,19 @@ rssm_dim: 1024 kl_loss_scale: 1.0 kl_loss_balancing: 0.8 kl_loss_free_nats: 0.0 -world_model_lr: 3e-4 +world_model_lr: 2e-4 world_model_predict_discount: true # ActorCritic parameters discount_factor: 0.999 imagination_horizon: 15 -actor_lr: 3e-4 +actor_lr: 2e-4 # automatically chooses depending on discrete/continuous env actor_reinforce_fraction: null actor_entropy_scale: 3e-3 -critic_lr: 3e-4 +critic_lr: 2e-4 # Lambda parameter for trainin deeper multi-step prediction critic_value_target_lambda: 0.95 critic_update_interval: 100 @@ -29,4 +29,5 @@ critic_soft_update_fraction: 1 discrete_rssm: false decode_vit: true +vit_l2_ratio: 1.0 encode_vit: false diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 7a073f7..195cb71 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -9,7 +9,7 @@ device_type: cuda logger: type: tensorboard - message: Quadruped with 224px input inside ViT + message: Quadruped with 224px input inside ViT and 16 patch size #message: test_last log_grads: false From 3941234689501933d21576ee4feb25bde3243f6d Mon Sep 17 00:00:00 2001 From: Midren Date: Fri, 7 Apr 2023 21:37:43 +0100 Subject: [PATCH 054/106] Fixes to work with torch.compile --- rl_sandbox/agents/dreamer_v2.py | 66 +++++++++++-------------- 
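# Context for the patch-size switch in the previous commit: with a square input the DINO
# token grid is img_size // patch_size per side, and ViT-S keeps feat_dim = 384 for both
# variants, which is why the predicted feature map goes from 28x28 to 14x14. A quick,
# illustrative sanity check:
img_size, feat_dim = 224, 384
for patch_size in (8, 16):
    side = img_size // patch_size                      # 28 for patch-8, 14 for patch-16
    print(patch_size, side, side * side * feat_dim)    # grid side and feature values per frame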
rl_sandbox/config/agent/dreamer_v2.yaml | 2 +- rl_sandbox/config/config.yaml | 12 ++--- rl_sandbox/vision/dino.py | 12 ++--- 4 files changed, 41 insertions(+), 51 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index cd4dd18..daebf22 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -298,7 +298,7 @@ def estimate_stochastic_latent(self, prev_determ: torch.Tensor): # taking only one random between all ensembles # NOTE: in Dreamer ensemble_num is always 1 idx = torch.randint(0, self.ensemble_num, ()) - return dists_per_model[idx] + return dists_per_model[0] def predict_next(self, prev_state: State, @@ -424,7 +424,7 @@ def __init__(self, momentum=0.99, scale=1.0, eps=1e-8): self.momentum = momentum self.scale = scale self.eps= eps - self.mag = torch.ones(1, dtype=torch.float32) + self.register_buffer('mag', torch.ones(1, dtype=torch.float32)) self.mag.requires_grad = False def forward(self, x): @@ -432,7 +432,7 @@ def forward(self, x): return (x / (self.mag + self.eps))*self.scale def update(self, x): - self.mag = self.momentum * self.mag.to(x.device) + (1 - self.momentum) * (x.abs().mean()).detach() + self.mag = self.momentum * self.mag + (1 - self.momentum) * (x.abs().mean()).detach() class WorldModel(nn.Module): @@ -441,7 +441,7 @@ def __init__(self, img_size, batch_cluster_size, latent_dim, latent_classes, rss actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float): super().__init__() - self.kl_free_nats = kl_free_nats + self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) self.kl_beta = kl_loss_scale self.rssm_dim = rssm_dim self.latent_dim = latent_dim @@ -549,24 +549,31 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, d_c = discount.reshape(-1, self.cluster_size, 1) first_c = first.reshape(-1, self.cluster_size, 1) - losses = defaultdict(lambda: torch.zeros(1).to(next(self.parameters()).device)) - metrics = defaultdict(lambda: torch.zeros(1).to(next(self.parameters()).device)) + losses = {} + metrics = {} - def KL(dist1, dist2, free_nat = True): + def KL(dist1, dist2): KL_ = torch.distributions.kl_divergence - one = self.kl_free_nats * torch.ones(1, device=next(self.parameters()).device) - # TODO: kl_free_avg is used always - if free_nat: - kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)).mean(), one) - kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())).mean(), one) - else: - kl_lhs = KL_(Dist(dist2.detach()), Dist(dist1)).mean() - kl_rhs = KL_(Dist(dist2), Dist(dist1.detach())).mean() + kl_lhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), td.OneHotCategoricalStraightThrough(logits=dist1)).mean() + kl_rhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2), td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() + kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) + kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) priors = [] posteriors = [] + if self.decode_vit: + inp = obs + if not self.encode_vit: + ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(224, antialias=True)]) + # ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + # (0.229, 0.224, 0.225)) + inp = ToTensor(obs + 0.5) + d_features = self.dino_vit(inp) + prev_state = 
self.get_initial_state(b // self.cluster_size) for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 @@ -579,7 +586,7 @@ def KL(dist1, dist2, free_nat = True): priors.append(prior) posteriors.append(posterior) - losses['loss_determ_recons'] += diff + # losses['loss_determ_recons'] += diff posterior = State.stack(posteriors) prior = State.stack(priors) @@ -587,7 +594,7 @@ def KL(dist1, dist2, free_nat = True): r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) - losses['loss_reconstruction_img'] = 0 + losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) if not self.decode_vit: x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) @@ -601,16 +608,7 @@ def KL(dist1, dist2, free_nat = True): x_r_detached = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1).detach()) losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() d_pred = self.dino_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) - inp = obs - if not self.encode_vit: - ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)), - tv.transforms.Resize(224, antialias=True)]) - # ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - # (0.229, 0.224, 0.225)) - inp = ToTensor(obs + 0.5) - d_features = self.dino_vit(inp) - losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 14, 14)).float().mean()/2 + + losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 14, 14)).float().mean()/4 + (1-self.vit_l2_ratio) * img_rec) prior_logits = prior.stoch_logits @@ -937,12 +935,10 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # take some latent embeddings as initial with torch.cuda.amp.autocast(enabled=True): - losses, discovered_states, wm_metrics = self.world_model.calculate_loss( - obs, a, r, discount_factors, first_flags) + losses, discovered_states, wm_metrics = self.world_model.calculate_loss(obs, a, r, discount_factors, first_flags) self.world_model.recurrent_model.discretizer_scheduler.step() # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device - world_model_loss = torch.Tensor(0).to(self.device) image_predictor_loss = losses['loss_reconstruction_img'] world_model_loss = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + @@ -965,12 +961,10 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation nn.utils.clip_grad_norm_(self.world_model.parameters(), 100) self.scaler.step(self.world_model_optimizer) - metrics = defaultdict( - lambda: torch.zeros(1).to(self.device)) + metrics = {} with torch.cuda.amp.autocast(enabled=True): - losses_ac = defaultdict( - lambda: torch.zeros(1).to(self.device)) + losses_ac = {} initial_states = State(discovered_states.determ.flatten(0, 1).unsqueeze(0).detach(), discovered_states.stoch_logits.flatten(0, 1).unsqueeze(0).detach(), discovered_states.stoch_.flatten(0, 1).unsqueeze(0).detach()) @@ -1000,14 +994,14 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation action_dists = self.actor(zs[:-2].detach()) baseline = self.critic.target_critic(zs[:-2]).mode advantage = (vs[1:] - baseline).detach() - losses_ac['loss_actor_reinforce'] += -(self.rho * action_dists.log_prob(actions[1:-1].detach()).unsqueeze(2) * discount_factors[:-2] * 
advantage).mean() + losses_ac['loss_actor_reinforce'] = -(self.rho * action_dists.log_prob(actions[1:-1].detach()).unsqueeze(2) * discount_factors[:-2] * advantage).mean() losses_ac['loss_actor_dynamics_backprop'] = -((1 - self.rho) * (vs[1:]*discount_factors[:-2])).mean() def calculate_entropy(dist): return dist.entropy().unsqueeze(2) # return dist.base_dist.base_dist.entropy().unsqueeze(2) - losses_ac['loss_actor_entropy'] += -(self.eta * calculate_entropy(action_dists)*discount_factors[:-2]).mean() + losses_ac['loss_actor_entropy'] = -(self.eta * calculate_entropy(action_dists)*discount_factors[:-2]).mean() losses_ac['loss_actor'] = losses_ac['loss_actor_reinforce'] + losses_ac['loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] # mean and std are estimated statistically as tanh transformation is used diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index 603c257..a36ca9b 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -7,7 +7,7 @@ latent_classes: 32 rssm_dim: 200 kl_loss_scale: 1.0 kl_loss_balancing: 0.8 -kl_loss_free_nats: 0.25 +kl_loss_free_nats: 1.0 world_model_lr: 3e-4 world_model_predict_discount: false diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 195cb71..c029dd1 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,6 +1,6 @@ defaults: - agent: dreamer_v2 - - env: dm_quadruped + - env: dm_acrobot - training: dm - _self_ @@ -9,14 +9,13 @@ device_type: cuda logger: type: tensorboard - message: Quadruped with 224px input inside ViT and 16 patch size - #message: test_last + message: Acrobot reference with 1 nat log_grads: false training: checkpoint_path: null - steps: 1e6 - batch_size: 16 + steps: 1.5e6 + batch_size: 50 val_logs_every: 2e4 @@ -24,8 +23,5 @@ validation: rollout_num: 5 visualize: true -#agents: -# batch_cluster_size: 10 - debug: profiler: false diff --git a/rl_sandbox/vision/dino.py b/rl_sandbox/vision/dino.py index 87f36c1..5952379 100644 --- a/rl_sandbox/vision/dino.py +++ b/rl_sandbox/vision/dino.py @@ -131,7 +131,7 @@ def forward(self, x): x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) - return x, attn + return x, qkv, attn class Block(nn.Module): @@ -147,9 +147,9 @@ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, x, return_attention=False): - y, attn = self.attn(self.norm1(x)) + y, qkv, attn = self.attn(self.norm1(x)) if return_attention: - return attn + return qkv, attn x = x + self.drop_path(y) x = x + self.drop_path(self.mlp(self.norm2(x))) return x @@ -319,17 +319,17 @@ def forward(self, img) : def hook_fn_forward_qkv(module, input, output): feat_out["qkv"] = output - self.model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) + # self.model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) # Forward pass in the model with torch.no_grad() : h, w = img.shape[2], img.shape[3] feat_h, feat_w = h // self.patch_size, w // self.patch_size - attentions = self.model.get_last_selfattention(img) + qkv, attentions = self.model.get_last_selfattention(img) bs, nb_head, nb_token = attentions.shape[0], attentions.shape[1], attentions.shape[2] qkv = ( - feat_out["qkv"] + qkv .reshape(bs, nb_token, 3, nb_head, -1) 
.permute(2, 0, 3, 1, 4) ) From f2c090605aba035a925cc55ea655fb8356c61f0d Mon Sep 17 00:00:00 2001 From: Midren Date: Fri, 7 Apr 2023 21:38:27 +0100 Subject: [PATCH 055/106] Working slot attention implementation, with 224px resize --- rl_sandbox/vision/my_slot_attention.py | 123 +++++++++++++++---------- 1 file changed, 74 insertions(+), 49 deletions(-) diff --git a/rl_sandbox/vision/my_slot_attention.py b/rl_sandbox/vision/my_slot_attention.py index 4a4ffd5..977e932 100644 --- a/rl_sandbox/vision/my_slot_attention.py +++ b/rl_sandbox/vision/my_slot_attention.py @@ -8,10 +8,10 @@ from tqdm import tqdm import numpy as np -from rl_sandbox.vision.vae import ResBlock +from rl_sandbox.vision.dino import ViTFeat class SlotAttention(nn.Module): - def __init__(self, num_slots: int, seq_num: int, n_dim: int, n_iter: int): + def __init__(self, num_slots: int, seq_num: int, n_dim: int, n_iter: int, mlp_hidden: int = 128): super().__init__() self.seq_num = seq_num @@ -21,32 +21,35 @@ def __init__(self, num_slots: int, seq_num: int, n_dim: int, n_iter: int): self.scale = self.n_dim**(-1/2) self.epsilon = 1e-8 - # self.norm = LayerNorm() self.slots_mu = nn.Parameter(torch.randn(1, 1, self.n_dim)) self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.n_dim)) nn.init.xavier_uniform_(self.slots_logsigma) self.slots_proj = nn.Linear(n_dim, n_dim) - self.slots_proj_2 = nn.Linear(n_dim, n_dim) + self.slots_proj_2 = nn.Sequential( + nn.Linear(n_dim, mlp_hidden), + nn.ReLU(inplace=True), + nn.Linear(mlp_hidden, n_dim), + ) self.slots_norm = nn.LayerNorm(self.n_dim) self.slots_norm_2 = nn.LayerNorm(self.n_dim) self.slots_reccur = nn.GRUCell(input_size=self.n_dim, hidden_size=self.n_dim) - self.inputs_proj = nn.Linear(n_dim, n_dim*2, ) + self.inputs_proj = nn.Linear(n_dim, n_dim*2) self.inputs_norm = nn.LayerNorm(self.n_dim) - def forward(self, X: Float[torch.Tensor, 'batch seq h w']) -> Float[torch.Tensor, 'batch num_slots n_dim']: - batch, seq, _, _ = X.shape - slots = self.slots_logsigma.exp() + self.slots_mu*torch.randn(batch, self.n_slots, self.n_dim, device=X.device) + def forward(self, X: Float[torch.Tensor, 'batch seq n_dim']) -> Float[torch.Tensor, 'batch num_slots n_dim']: + batch, _, _ = X.shape + k, v = self.inputs_proj(self.inputs_norm(X)).chunk(2, dim=-1) + + slots = self.slots_mu + self.slots_logsigma.exp() * torch.randn(batch, self.n_slots, self.n_dim, device=X.device) for _ in range(self.n_iter): slots_prev = slots - k, v = self.inputs_proj(self.inputs_norm(X).permute(0, 2, 3, 1).reshape(batch, -1, self.n_dim)).chunk(2, dim=-1) slots = self.slots_norm(slots) q = self.slots_proj(slots) attn = F.softmax(self.scale*torch.einsum('bik,bjk->bij', q, k), dim=1) + self.epsilon - attn = attn / attn.sum(dim=-1, keepdim=True) updates = torch.einsum('bij,bjk->bik', attn, v) / self.n_slots @@ -65,58 +68,70 @@ def build_grid(resolution): class PositionalEmbedding(nn.Module): - def __init__(self, n_dim: int): + def __init__(self, n_dim: int, res: t.Tuple[int, int]): super().__init__() self.n_dim = n_dim self.proj = nn.Linear(4, n_dim) - self.grid = torch.from_numpy(build_grid((64, 64))) + self.register_buffer('grid', torch.from_numpy(build_grid(res))) def forward(self, X) -> torch.Tensor: - return X + self.proj(self.grid.to(X.device)) - # xs = torch.arange(self.n_dim) - # pos = torch.zeros_like(xs, dtype=torch.float) - # for i in range(self.n_dim): - # pos += torch.sin(xs / (1e4**(2*i/self.n_dim))) - # return X + pos + return X + self.proj(self.grid).permute(0, 3, 1, 2) class SlottedAutoEncoder(nn.Module): 
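# A minimal shape-level sketch of how the SlotAttention module above is used, going by its
# type annotations (batch, seq, n_dim) -> (batch, num_slots, n_dim); the numbers are only
# an example:
#
#     slot_attn = SlotAttention(num_slots=7, seq_num=16, n_dim=196, n_iter=3)
#     tokens = torch.randn(4, 13 * 13, 196)     # flattened 13x13 feature grid
#     slots = slot_attn(tokens)                 # -> (4, 7, 196)
#
# Inside the loop the softmax over dim=1 makes the slots compete for input tokens, and the
# GRU cell carries each slot's state across the n_iter refinement rounds.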
def __init__(self, num_slots: int, n_iter: int): super().__init__() in_channels = 3 latent_dim = 16 + self.n_dim = 196 + self.lat_dim = int(self.n_dim**0.5) self.encoder = nn.Sequential( - nn.Conv2d(in_channels, 64, kernel_size=5, padding='same'), - nn.LeakyReLU(inplace=True), - nn.Conv2d(64, 64, kernel_size=5, padding='same'), - nn.LeakyReLU(inplace=True), - nn.Conv2d(64, 64, kernel_size=5, padding='same'), - nn.LeakyReLU(inplace=True), - nn.Conv2d(64, 64, kernel_size=5, padding='same'), - nn.LeakyReLU(inplace=True), + nn.Conv2d(in_channels, 64, kernel_size=5, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, self.n_dim, kernel_size=5, padding='same'), + nn.ReLU(inplace=True), + ) + + self.mlp = nn.Sequential( + nn.Linear(self.n_dim, self.n_dim), + nn.ReLU(inplace=True), + nn.Linear(self.n_dim, self.n_dim) ) + seq_num = latent_dim - n_dim = 64 - self.positional_augmenter = PositionalEmbedding(n_dim) - self.slot_attention = SlotAttention(num_slots, seq_num, n_dim, n_iter) + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=16) + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (13, 13)) + self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (self.lat_dim, self.lat_dim)) + self.slot_attention = SlotAttention(num_slots, seq_num, self.n_dim, n_iter) self.decoder = nn.Sequential( + nn.ConvTranspose2d(self.n_dim, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.LeakyReLU(inplace=True), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.LeakyReLU(inplace=True), + nn.ReLU(inplace=True), nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.LeakyReLU(inplace=True), + nn.ReLU(inplace=True), nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1), ) def forward(self, X: Float[torch.Tensor, 'batch 3 h w']) -> t.Tuple[Float[torch.Tensor, 'batch 3 h w'], Float[torch.Tensor, 'batch num_slots 4 h w']]: - features = self.encoder(X) # -> batch D 64 - features_with_pos_enc = self.positional_augmenter(features) # -> batch D 64 + features = self.encoder(X) # -> batch D h w + # vit_features = self.dino_vit(X) + features_with_pos_enc = self.positional_augmenter_inp(features) # -> batch D h w + + batch, seq, _, _ = X.shape + pre_slot_features = self.mlp(features_with_pos_enc.permute(0, 2, 3, 1).reshape(batch, -1, self.n_dim)) - slots = self.slot_attention(features_with_pos_enc) # -> batch num_slots 64 - slots = slots.flatten(0, 1).reshape(-1, 1, 1, 64) - slots = slots.repeat((1, 8, 8, 1)) + slots = self.slot_attention(pre_slot_features) # -> batch num_slots D + slots = slots.flatten(0, 1).reshape(-1, 1, 1, self.n_dim) + slots = slots.repeat((1, self.lat_dim, self.lat_dim, 1)).permute(0, 3, 1, 2) # -> batch*num_slots D sqrt(D) sqrt(D) + slots_with_pos_enc = self.positional_augmenter_dec(slots) - decoded_imgs, masks = self.decoder(slots.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).reshape(X.shape[0], -1, *X.shape[2:], 4).split([3, 1], dim=-1) + decoded_imgs, masks = self.decoder(slots_with_pos_enc).permute(0, 2, 3, 1).reshape(X.shape[0], 
-1, *(np.array(X.shape[2:])//2), 4).split([3, 1], dim=-1) decoded_imgs = decoded_imgs * F.softmax(masks, dim=1) rec_img = torch.sum(decoded_imgs, dim=1) @@ -125,17 +140,25 @@ def forward(self, X: Float[torch.Tensor, 'batch 3 h w']) -> t.Tuple[Float[torch. if __name__ == '__main__': device = 'cuda' ToTensor = tv.transforms.Compose([tv.transforms.ToTensor(), - tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ] - ) + tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + tv.transforms.Resize(224, antialias=True), + ]) train_data = tv.datasets.ImageFolder('~/rl_old/rl_sandbox/crafter_data_2/', transform=ToTensor) + # train_data_loader = torch.utils.data.DataLoader(train_data, + # batch_size=32, + # shuffle=True, + # num_workers=8) + train_data_loader = torch.utils.data.DataLoader(train_data, - batch_size=64, - shuffle=True, - num_workers=8) + batch_size=4, + prefetch_factor=1, + shuffle=False, + num_workers=2) import socket from datetime import datetime current_time = datetime.now().strftime("%b%d_%H-%M-%S") - logger = SummaryWriter(log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}", comment="Added lr scheduler") + comment = "Reconstruct only 112px for faster calc".replace(" ", "_") + # logger = SummaryWriter(log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}_{comment}") number_of_slots = 7 slots_iter_num = 3 @@ -146,6 +169,7 @@ def forward(self, X: Float[torch.Tensor, 'batch 3 h w']) -> t.Tuple[Float[torch. decay_steps = 1e5 model = SlottedAutoEncoder(number_of_slots, slots_iter_num).to(device) + # opt_model = torch.compile(model, mode='auto-maxtune') optimizer = torch.optim.Adam(model.parameters(), lr=4e-4) lr_warmup_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1/warmup_steps, total_iters=int(warmup_steps)) lr_decay_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: decay_rate**(epoch/decay_steps)) @@ -156,13 +180,10 @@ def forward(self, X: Float[torch.Tensor, 'batch 3 h w']) -> t.Tuple[Float[torch. epoch = 0 pbar = tqdm(total=total_steps, desc='Training') while global_step < total_steps: - epoch += 1 - logger.add_scalar('epoch', epoch, epoch) - for sample_num, (img, target) in enumerate(train_data_loader): recovered_img, _ = model(img.to(device)) - reg_loss = F.mse_loss(img.to(device), recovered_img) + reg_loss = F.mse_loss(img.to(device)[:, :, ::2, ::2], recovered_img) optimizer.zero_grad() reg_loss.backward() @@ -183,3 +204,7 @@ def forward(self, X: Float[torch.Tensor, 'batch 3 h w']) -> t.Tuple[Float[torch. 
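# Note: the `unnormalize` compose above is the inverse of the ImageNet normalization applied
# in the `ToTensor` transform: Normalize(0, 1/std) multiplies by std and Normalize(-mean, 1)
# then adds mean, so unnormalize((x - mean) / std) recovers the original [0, 1] image for
# logging.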
logger.add_image(f'val/example_image_rec', unnormalize(recovered_img.cpu().detach()[0]), epoch*3 + i) for i in range(6): logger.add_image(f'val/example_image_rec_{i}', unnormalize(imgs_per_slot.cpu().detach()[0][i]), epoch*3 + i) + + logger.add_scalar('epoch', epoch, epoch) + epoch += 1 + From 9d2839f0c8a5c73e994a5bbcab16bcaa4b668d5d Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Tue, 11 Apr 2023 11:29:38 +0100 Subject: [PATCH 056/106] Added DINO features reconstruction for slot attention --- rl_sandbox/agents/dreamer_v2.py | 2 +- .../config/agent/dreamer_v2_crafter.yaml | 2 +- rl_sandbox/utils/logger.py | 8 +- rl_sandbox/vision/my_slot_attention.py | 124 +++++++++++++----- 4 files changed, 98 insertions(+), 38 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index daebf22..bb5fa9b 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -946,7 +946,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation losses['loss_discount_pred']) # for l in losses.values(): # world_model_loss += l - if self.world_model.decode_vit and self.world_model.vit_l2_ratio != 1.0: + if self.world_model.decode_vit and self.world_model.vit_l2_ratio == 1.0: self.image_predictor_optimizer.zero_grad(set_to_none=True) self.scaler.scale(image_predictor_loss).backward() self.scaler.unscale_(self.image_predictor_optimizer) diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index bd181db..0f5423c 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -29,5 +29,5 @@ critic_soft_update_fraction: 1 discrete_rssm: false decode_vit: true -vit_l2_ratio: 1.0 +vit_l2_ratio: 0.5 encode_vit: false diff --git a/rl_sandbox/utils/logger.py b/rl_sandbox/utils/logger.py index c263049..c3981d8 100644 --- a/rl_sandbox/utils/logger.py +++ b/rl_sandbox/utils/logger.py @@ -25,11 +25,13 @@ def add_figure(*args, **kwargs): class Logger: def __init__(self, type: t.Optional[str], message: t.Optional[str] = None, - log_grads: bool = True) -> None: + log_grads: bool = True, + log_dir: t.Optional[str] = None + ) -> None: self.type = type match type: case "tensorboard": - self.writer = SummaryWriter(comment=message or "") + self.writer = SummaryWriter(comment=message or "", log_dir=log_dir) case None: self.writer = SummaryWriterMock() case _: @@ -49,7 +51,7 @@ def add_scalar(self, name: str, value: t.Any, global_step: int): self.writer.add_scalar(name, value, global_step) def add_image(self, name: str, image: t.Any, global_step: int): - self.writer.add_image(name, image, global_step, dataformats='HW') + self.writer.add_image(name, image, global_step) def add_video(self, name: str, video: t.Any, global_step: int): self.writer.add_video(name, video, global_step, fps=20) diff --git a/rl_sandbox/vision/my_slot_attention.py b/rl_sandbox/vision/my_slot_attention.py index 977e932..e902020 100644 --- a/rl_sandbox/vision/my_slot_attention.py +++ b/rl_sandbox/vision/my_slot_attention.py @@ -9,9 +9,10 @@ import numpy as np from rl_sandbox.vision.dino import ViTFeat +from rl_sandbox.utils.logger import Logger class SlotAttention(nn.Module): - def __init__(self, num_slots: int, seq_num: int, n_dim: int, n_iter: int, mlp_hidden: int = 128): + def __init__(self, num_slots: int, seq_num: int, n_dim: int, n_iter: int): super().__init__() self.seq_num = seq_num @@ -27,9 +28,9 @@ def __init__(self, num_slots: int, seq_num: int, n_dim: int, 
n_iter: int, mlp_hi self.slots_proj = nn.Linear(n_dim, n_dim) self.slots_proj_2 = nn.Sequential( - nn.Linear(n_dim, mlp_hidden), + nn.Linear(n_dim, n_dim*4), nn.ReLU(inplace=True), - nn.Linear(mlp_hidden, n_dim), + nn.Linear(n_dim*4, n_dim), ) self.slots_norm = nn.LayerNorm(self.n_dim) self.slots_norm_2 = nn.LayerNorm(self.n_dim) @@ -107,8 +108,9 @@ def __init__(self, num_slots: int, n_iter: int): self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=16) self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (13, 13)) self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (self.lat_dim, self.lat_dim)) + self.positional_augmenter_vit_dec = PositionalEmbedding(self.n_dim, (14, 14)) self.slot_attention = SlotAttention(num_slots, seq_num, self.n_dim, n_iter) - self.decoder = nn.Sequential( + self.img_decoder = nn.Sequential( # Dx14x14 -> (3+1)x112x112 nn.ConvTranspose2d(self.n_dim, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(inplace=True), nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), @@ -118,47 +120,86 @@ def __init__(self, num_slots: int, n_iter: int): nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1), ) + self.vit_decoder = nn.Sequential( # Dx1x1 -> (384+1)x14x14 + nn.ConvTranspose2d(self.n_dim, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(192, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(192, 384, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(384, 576, kernel_size=3, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(576, 385, kernel_size=3, stride=(1, 1), padding=1), + ) + + # self.vit_decoder_mlp = nn.Sequential( + # nn.Linear(self.n_dim, 1024), + # nn.ReLU(inplace=True), + # nn.Linear(1024, 1024), + # nn.ReLU(inplace=True), + # nn.Linear(1024, 1024), + # nn.ReLU(inplace=True), + # nn.Linear(1024, 385), + # nn.ReLU(inplace=True) + # ) + def forward(self, X: Float[torch.Tensor, 'batch 3 h w']) -> t.Tuple[Float[torch.Tensor, 'batch 3 h w'], Float[torch.Tensor, 'batch num_slots 4 h w']]: features = self.encoder(X) # -> batch D h w - # vit_features = self.dino_vit(X) features_with_pos_enc = self.positional_augmenter_inp(features) # -> batch D h w batch, seq, _, _ = X.shape + vit_features = self.dino_vit(X) + vit_res_num = int(vit_features.shape[-1]**0.5) + vit_features = vit_features.reshape(batch, -1, vit_res_num, vit_res_num) + pre_slot_features = self.mlp(features_with_pos_enc.permute(0, 2, 3, 1).reshape(batch, -1, self.n_dim)) slots = self.slot_attention(pre_slot_features) # -> batch num_slots D - slots = slots.flatten(0, 1).reshape(-1, 1, 1, self.n_dim) - slots = slots.repeat((1, self.lat_dim, self.lat_dim, 1)).permute(0, 3, 1, 2) # -> batch*num_slots D sqrt(D) sqrt(D) + slots = slots.flatten(0, 1).reshape(-1, 1, 1, self.n_dim).permute(0, 3, 1, 2) + + # slots_with_vit_pos_enc = self.positional_augmenter_vit_dec(slots.flatten(2, 3).repeat((1, 1, 196)).reshape(-1, self.n_dim, 14, 14)).flatten(2, 3) + # decoded_features, vit_masks =self.vit_decoder_mlp(slots_with_vit_pos_enc).reshape(batch, -1, vit_res_num, vit_res_num, 385).split([384, 1], dim=-1) + + decoded_features, vit_masks = self.vit_decoder(slots).permute(0, 2, 3, 1).reshape(batch, -1, vit_res_num, vit_res_num, 
385).split([384, 1], dim=-1) + vit_mask = F.softmax(vit_masks, dim=1) + + rec_features = (decoded_features * vit_mask).sum(dim=1) + + slots = slots.repeat((1, 1, self.lat_dim, self.lat_dim)) # -> batch*num_slots D sqrt(D) sqrt(D) slots_with_pos_enc = self.positional_augmenter_dec(slots) - decoded_imgs, masks = self.decoder(slots_with_pos_enc).permute(0, 2, 3, 1).reshape(X.shape[0], -1, *(np.array(X.shape[2:])//2), 4).split([3, 1], dim=-1) + decoded_imgs, masks = self.img_decoder(slots_with_pos_enc).permute(0, 2, 3, 1).reshape(batch, -1, *(np.array(X.shape[2:])//2), 4).split([3, 1], dim=-1) + img_mask = F.softmax(masks, dim=1) - decoded_imgs = decoded_imgs * F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask rec_img = torch.sum(decoded_imgs, dim=1) - return rec_img.permute(0, 3, 1, 2), decoded_imgs.permute(0, 1, 4, 2, 3) + return rec_img.permute(0, 3, 1, 2), decoded_imgs.permute(0, 1, 4, 2, 3), F.mse_loss(rec_features.permute(0, 3, 1, 2), vit_features), vit_mask if __name__ == '__main__': device = 'cuda' + debug = False ToTensor = tv.transforms.Compose([tv.transforms.ToTensor(), tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), tv.transforms.Resize(224, antialias=True), ]) - train_data = tv.datasets.ImageFolder('~/rl_old/rl_sandbox/crafter_data_2/', transform=ToTensor) - # train_data_loader = torch.utils.data.DataLoader(train_data, - # batch_size=32, - # shuffle=True, - # num_workers=8) - - train_data_loader = torch.utils.data.DataLoader(train_data, - batch_size=4, - prefetch_factor=1, - shuffle=False, - num_workers=2) + train_data = tv.datasets.ImageFolder('~/rl_sandbox/crafter_data/', transform=ToTensor) + if debug: + train_data_loader = torch.utils.data.DataLoader(train_data, + batch_size=4, + prefetch_factor=1, + shuffle=False, + num_workers=2) + else: + train_data_loader = torch.utils.data.DataLoader(train_data, + batch_size=32, + shuffle=True, + num_workers=8) + import socket from datetime import datetime current_time = datetime.now().strftime("%b%d_%H-%M-%S") - comment = "Reconstruct only 112px for faster calc".replace(" ", "_") - # logger = SummaryWriter(log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}_{comment}") + comment = "Added vit masks logging, lambda=0.1, return old dino".replace(" ", "_") + logger = Logger(None if debug else 'tensorboard', message=comment, log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}_{comment}") number_of_slots = 7 slots_iter_num = 3 @@ -167,9 +208,10 @@ def forward(self, X: Float[torch.Tensor, 'batch 3 h w']) -> t.Tuple[Float[torch. warmup_steps = 1e4 decay_rate = 0.5 decay_steps = 1e5 + val_every = 1e4 model = SlottedAutoEncoder(number_of_slots, slots_iter_num).to(device) - # opt_model = torch.compile(model, mode='auto-maxtune') + # model = torch.compile(model) optimizer = torch.optim.Adam(model.parameters(), lr=4e-4) lr_warmup_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1/warmup_steps, total_iters=int(warmup_steps)) lr_decay_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: decay_rate**(epoch/decay_steps)) @@ -177,34 +219,50 @@ def forward(self, X: Float[torch.Tensor, 'batch 3 h w']) -> t.Tuple[Float[torch. 
lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler([lr_warmup_scheduler, lr_decay_scheduler]) global_step = 0 + prev_global_step = 0 epoch = 0 pbar = tqdm(total=total_steps, desc='Training') while global_step < total_steps: for sample_num, (img, target) in enumerate(train_data_loader): - recovered_img, _ = model(img.to(device)) + recovered_img, _, vit_rec_loss, _ = model(img.to(device)) reg_loss = F.mse_loss(img.to(device)[:, :, ::2, ::2], recovered_img) + lambda_ = 0.1 + loss = lambda_ * reg_loss + (1 - lambda_) * vit_rec_loss + optimizer.zero_grad() - reg_loss.backward() + loss.backward() optimizer.step() lr_scheduler.step() - logger.add_scalar('train/loss', reg_loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) + logger.add_scalar('train/img_rec_loss', reg_loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) + logger.add_scalar('train/vit_rec_loss', vit_rec_loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) + logger.add_scalar('train/loss', loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) pbar.update(1) + global_step += len(train_data_loader) + + epoch += 1 + logger.add_scalar('epoch', epoch, epoch) + + if global_step - prev_global_step > val_every: + prev_global_step = global_step + else: + continue for i in range(3): img, target = next(iter(train_data_loader)) - recovered_img, imgs_per_slot = model(img.to(device)) + recovered_img, imgs_per_slot, _, vit_mask = model(img.to(device)) + upscale = tv.transforms.Resize(224, antialias=True) unnormalize = tv.transforms.Compose([ tv.transforms.Normalize((0, 0, 0), (1/0.229, 1/0.224, 1/0.225)), tv.transforms.Normalize((-0.485, -0.456, -0.406), (1., 1., 1.)) ]) logger.add_image(f'val/example_image', unnormalize(img.cpu().detach()[0]), epoch*3 + i) logger.add_image(f'val/example_image_rec', unnormalize(recovered_img.cpu().detach()[0]), epoch*3 + i) - for i in range(6): - logger.add_image(f'val/example_image_rec_{i}', unnormalize(imgs_per_slot.cpu().detach()[0][i]), epoch*3 + i) - - logger.add_scalar('epoch', epoch, epoch) - epoch += 1 + per_slot_img = unnormalize(imgs_per_slot.cpu().detach())[0].permute((1, 2, 0, 3)).flatten(2, 3) + logger.add_image(f'val/example_image_slot_rec', per_slot_img, epoch*3 + i) + upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) + per_slot_vit = (upscaled_mask.unsqueeze(2) * img.to(device).unsqueeze(1))[0].permute(1, 2, 0, 3).flatten(2, 3) + logger.add_image(f'val/example_vit_slot_mask', per_slot_vit/upscaled_mask.max(), epoch*3 + i) From f060b5a4c2ecac6f45addd9535e1aaf01e6eb200 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Tue, 11 Apr 2023 18:12:57 +0100 Subject: [PATCH 057/106] 64px encoder, add possibility to use prev_slots --- rl_sandbox/vision/my_slot_attention.py | 84 ++++++++++++++------------ 1 file changed, 47 insertions(+), 37 deletions(-) diff --git a/rl_sandbox/vision/my_slot_attention.py b/rl_sandbox/vision/my_slot_attention.py index e902020..b83bc11 100644 --- a/rl_sandbox/vision/my_slot_attention.py +++ b/rl_sandbox/vision/my_slot_attention.py @@ -4,7 +4,6 @@ import torch.nn.functional as F from jaxtyping import Float import torchvision as tv -from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm import numpy as np @@ -12,10 +11,9 @@ from rl_sandbox.utils.logger import Logger class SlotAttention(nn.Module): - def __init__(self, num_slots: int, seq_num: int, n_dim: int, n_iter: int): + def __init__(self, num_slots: int, n_dim: int, n_iter: int): super().__init__() - self.seq_num = seq_num 
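# The optional `prev_slots` argument added to forward() below lets the slots from the
# previous frame seed the next call instead of resampling from the learned Gaussian. A
# minimal sketch of threading it over a clip, assuming `encode` yields (batch, seq, n_dim)
# tokens per frame:
#
#     slots = None
#     for frame in clip:                                    # clip: iterable of image batches
#         slots = slot_attn(encode(frame), prev_slots=slots)
#
# Passing prev_slots=None keeps the original behaviour of sampling a fresh initialization.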
self.n_slots = num_slots self.n_iter = n_iter self.n_dim = n_dim @@ -39,11 +37,14 @@ def __init__(self, num_slots: int, seq_num: int, n_dim: int, n_iter: int): self.inputs_proj = nn.Linear(n_dim, n_dim*2) self.inputs_norm = nn.LayerNorm(self.n_dim) - def forward(self, X: Float[torch.Tensor, 'batch seq n_dim']) -> Float[torch.Tensor, 'batch num_slots n_dim']: + def forward(self, X: Float[torch.Tensor, 'batch seq n_dim'], prev_slots: t.Optional[Float[torch.Tensor, 'batch num_slots n_dim']]) -> Float[torch.Tensor, 'batch num_slots n_dim']: batch, _, _ = X.shape k, v = self.inputs_proj(self.inputs_norm(X)).chunk(2, dim=-1) - slots = self.slots_mu + self.slots_logsigma.exp() * torch.randn(batch, self.n_slots, self.n_dim, device=X.device) + if prev_slots is None: + slots = self.slots_mu + self.slots_logsigma.exp() * torch.randn(batch, self.n_slots, self.n_dim, device=X.device) + else: + slots = prev_slots for _ in range(self.n_iter): slots_prev = slots @@ -79,11 +80,10 @@ def forward(self, X) -> torch.Tensor: return X + self.proj(self.grid).permute(0, 3, 1, 2) class SlottedAutoEncoder(nn.Module): - def __init__(self, num_slots: int, n_iter: int): + def __init__(self, num_slots: int, n_iter: int, dino_inp_size: int = 224): super().__init__() in_channels = 3 - latent_dim = 16 - self.n_dim = 196 + self.n_dim = 128 self.lat_dim = int(self.n_dim**0.5) self.encoder = nn.Sequential( nn.Conv2d(in_channels, 64, kernel_size=5, stride=2, padding=1), @@ -92,8 +92,6 @@ def __init__(self, num_slots: int, n_iter: int): nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1), - nn.ReLU(inplace=True), nn.Conv2d(64, self.n_dim, kernel_size=5, padding='same'), nn.ReLU(inplace=True), ) @@ -104,13 +102,16 @@ def __init__(self, num_slots: int, n_iter: int): nn.Linear(self.n_dim, self.n_dim) ) - seq_num = latent_dim + self.dino_inp_size = dino_inp_size self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=16) - self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (13, 13)) - self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (self.lat_dim, self.lat_dim)) - self.positional_augmenter_vit_dec = PositionalEmbedding(self.n_dim, (14, 14)) - self.slot_attention = SlotAttention(num_slots, seq_num, self.n_dim, n_iter) - self.img_decoder = nn.Sequential( # Dx14x14 -> (3+1)x112x112 + self.vit_patch_num = self.dino_inp_size // self.dino_vit.patch_size + self.vit_feat = self.dino_vit.feat_dim + + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (7, 7)) + self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) + self.positional_augmenter_vit_dec = PositionalEmbedding(self.n_dim, (self.lat_dim, self.lat_dim)) + self.slot_attention = SlotAttention(num_slots, self.n_dim, n_iter) + self.img_decoder = nn.Sequential( # Dx8x8 -> (3+1)x64x64 nn.ConvTranspose2d(self.n_dim, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(inplace=True), nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), @@ -125,11 +126,11 @@ def __init__(self, num_slots: int, n_iter: int): nn.ReLU(inplace=True), nn.ConvTranspose2d(192, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d(192, 384, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ConvTranspose2d(192, self.vit_feat, kernel_size=5, stride=(2, 2), 
padding=2, output_padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d(384, 576, kernel_size=3, stride=(2, 2), padding=2, output_padding=1), + nn.ConvTranspose2d(self.vit_feat, self.vit_feat*2, kernel_size=3, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d(576, 385, kernel_size=3, stride=(1, 1), padding=1), + nn.ConvTranspose2d(self.vit_feat*2, self.vit_feat+1, kernel_size=3, stride=(1, 1), padding=1), ) # self.vit_decoder_mlp = nn.Sequential( @@ -139,49 +140,56 @@ def __init__(self, num_slots: int, n_iter: int): # nn.ReLU(inplace=True), # nn.Linear(1024, 1024), # nn.ReLU(inplace=True), - # nn.Linear(1024, 385), + # nn.Linear(1024, self.vit_feat+1), # nn.ReLU(inplace=True) # ) - def forward(self, X: Float[torch.Tensor, 'batch 3 h w']) -> t.Tuple[Float[torch.Tensor, 'batch 3 h w'], Float[torch.Tensor, 'batch num_slots 4 h w']]: + def forward(self, X: Float[torch.Tensor, 'batch 3 h w'], prev_slots: t.Optional[Float[torch.Tensor, 'batch num_slots n_dim']] = None) -> t.Dict[str, torch.Tensor]: features = self.encoder(X) # -> batch D h w features_with_pos_enc = self.positional_augmenter_inp(features) # -> batch D h w + resize = tv.transforms.Resize(self.dino_inp_size, antialias=True) + batch, seq, _, _ = X.shape - vit_features = self.dino_vit(X) - vit_res_num = int(vit_features.shape[-1]**0.5) - vit_features = vit_features.reshape(batch, -1, vit_res_num, vit_res_num) + vit_features = self.dino_vit(resize(X)) + vit_features = vit_features.reshape(batch, -1, self.vit_patch_num, self.vit_patch_num) pre_slot_features = self.mlp(features_with_pos_enc.permute(0, 2, 3, 1).reshape(batch, -1, self.n_dim)) - slots = self.slot_attention(pre_slot_features) # -> batch num_slots D - slots = slots.flatten(0, 1).reshape(-1, 1, 1, self.n_dim).permute(0, 3, 1, 2) + slots = self.slot_attention(pre_slot_features, prev_slots) # -> batch num_slots D + slots_grid = slots.flatten(0, 1).reshape(-1, 1, 1, self.n_dim).permute(0, 3, 1, 2) - # slots_with_vit_pos_enc = self.positional_augmenter_vit_dec(slots.flatten(2, 3).repeat((1, 1, 196)).reshape(-1, self.n_dim, 14, 14)).flatten(2, 3) - # decoded_features, vit_masks =self.vit_decoder_mlp(slots_with_vit_pos_enc).reshape(batch, -1, vit_res_num, vit_res_num, 385).split([384, 1], dim=-1) + # slots_with_vit_pos_enc = self.positional_augmenter_vit_dec(slots_grid.flatten(2, 3).repeat((1, 1, 196)).reshape(-1, self.n_dim, self.lat_dim, self.lat_dim)).flatten(2, 3) + # decoded_features, vit_masks =self.vit_decoder_mlp(slots_with_vit_pos_enc).reshape(batch, -1, self.vit_patch_num, self.vit_patch_num, self.vit_feat+1).split([self.vit_feat, 1], dim=-1) - decoded_features, vit_masks = self.vit_decoder(slots).permute(0, 2, 3, 1).reshape(batch, -1, vit_res_num, vit_res_num, 385).split([384, 1], dim=-1) + decoded_features, vit_masks = self.vit_decoder(slots_grid).permute(0, 2, 3, 1).reshape(batch, -1, self.vit_patch_num, self.vit_patch_num, self.vit_feat+1).split([self.vit_feat, 1], dim=-1) vit_mask = F.softmax(vit_masks, dim=1) rec_features = (decoded_features * vit_mask).sum(dim=1) - slots = slots.repeat((1, 1, self.lat_dim, self.lat_dim)) # -> batch*num_slots D sqrt(D) sqrt(D) - slots_with_pos_enc = self.positional_augmenter_dec(slots) + slots_grid = slots_grid.repeat((1, 1, 8, 8)) # -> batch*num_slots D sqrt(D) sqrt(D) + slots_with_pos_enc = self.positional_augmenter_dec(slots_grid) decoded_imgs, masks = self.img_decoder(slots_with_pos_enc).permute(0, 2, 3, 1).reshape(batch, -1, *(np.array(X.shape[2:])//2), 4).split([3, 1], dim=-1) img_mask 
= F.softmax(masks, dim=1) decoded_imgs = decoded_imgs * img_mask rec_img = torch.sum(decoded_imgs, dim=1) - return rec_img.permute(0, 3, 1, 2), decoded_imgs.permute(0, 1, 4, 2, 3), F.mse_loss(rec_features.permute(0, 3, 1, 2), vit_features), vit_mask + return { + 'rec_img': rec_img.permute(0, 3, 1, 2), + 'img_per_slot': decoded_imgs.permute(0, 1, 4, 2, 3), + 'vit_mask': vit_mask, + 'vit_rec_loss': F.mse_loss(rec_features.permute(0, 3, 1, 2), vit_features), + 'slots': slots + } if __name__ == '__main__': device = 'cuda' - debug = False + debug = True ToTensor = tv.transforms.Compose([tv.transforms.ToTensor(), tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - tv.transforms.Resize(224, antialias=True), ]) + train_data = tv.datasets.ImageFolder('~/rl_sandbox/crafter_data/', transform=ToTensor) if debug: train_data_loader = torch.utils.data.DataLoader(train_data, @@ -224,7 +232,8 @@ def forward(self, X: Float[torch.Tensor, 'batch 3 h w']) -> t.Tuple[Float[torch. pbar = tqdm(total=total_steps, desc='Training') while global_step < total_steps: for sample_num, (img, target) in enumerate(train_data_loader): - recovered_img, _, vit_rec_loss, _ = model(img.to(device)) + res = model(img.to(device)) + recovered_img, vit_rec_loss = res['rec_img'], res['vit_rec_loss'] reg_loss = F.mse_loss(img.to(device)[:, :, ::2, ::2], recovered_img) @@ -252,8 +261,9 @@ def forward(self, X: Float[torch.Tensor, 'batch 3 h w']) -> t.Tuple[Float[torch. for i in range(3): img, target = next(iter(train_data_loader)) - recovered_img, imgs_per_slot, _, vit_mask = model(img.to(device)) - upscale = tv.transforms.Resize(224, antialias=True) + res = model(img.to(device)) + recovered_img, imgs_per_slot, vit_mask = res['rec_img'], res['img_per_slot'], res['vit_mask'] + upscale = tv.transforms.Resize(64, antialias=True) unnormalize = tv.transforms.Compose([ tv.transforms.Normalize((0, 0, 0), (1/0.229, 1/0.224, 1/0.225)), tv.transforms.Normalize((-0.485, -0.456, -0.406), (1., 1., 1.)) From 5034573b5e228654ed4b7a42aeba4fb80f72e701 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Sat, 15 Apr 2023 10:17:58 +0000 Subject: [PATCH 058/106] Added first implementation of slot attention inside Dreamer --- rl_sandbox/agents/dreamer_v2.py | 265 ++++++++++++------------ rl_sandbox/config/agent/dreamer_v2.yaml | 11 +- rl_sandbox/config/config.yaml | 18 +- rl_sandbox/config/training/dm.yaml | 2 +- rl_sandbox/utils/dists.py | 79 +++++++ rl_sandbox/utils/logger.py | 4 +- rl_sandbox/vision/dino.py | 12 +- rl_sandbox/vision/my_slot_attention.py | 14 +- 8 files changed, 248 insertions(+), 157 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index bb5fa9b..64933fc 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -19,8 +19,9 @@ from rl_sandbox.utils.replay_buffer import (Action, Actions, Observation, Observations, Rewards, TerminationFlags, IsFirstFlags) -from rl_sandbox.utils.dists import TruncatedNormal from rl_sandbox.utils.schedulers import LinearScheduler +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.vision.my_slot_attention import SlotAttention, PositionalEmbedding class View(nn.Module): @@ -31,96 +32,27 @@ def __init__(self, shape): def forward(self, x): return x.view(*self.shape) - -class Sigmoid2(nn.Module): - def forward(self, x): - return 2*torch.sigmoid(x/2) - -class NormalWithOffset(nn.Module): - def __init__(self, min_std: float, std_trans: str = 'sigmoid2', transform: t.Optional[str] = None): - super().__init__() - 
self.min_std = min_std - match std_trans: - case 'identity': - self.std_trans = nn.Identity() - case 'softplus': - self.std_trans = nn.Softplus() - case 'sigmoid': - self.std_trans = nn.Sigmoid() - case 'sigmoid2': - self.std_trans = Sigmoid2() - case _: - raise RuntimeError("Unknown std transformation") - - match transform: - case 'tanh': - self.trans = [td.TanhTransform(cache_size=1)] - case None: - self.trans = None - case _: - raise RuntimeError("Unknown distribution transformation") - - def forward(self, x): - mean, std = x.chunk(2, dim=-1) - dist = td.Normal(mean, self.std_trans(std) + self.min_std) - if self.trans is None: - return dist - else: - return td.TransformedDistribution(dist, self.trans) - -class DistLayer(nn.Module): - def __init__(self, type: str): - super().__init__() - self._dist = type - match type: - case 'mse': - self.dist = lambda x: td.Normal(x.float(), 1.0) - case 'normal': - self.dist = NormalWithOffset(min_std=0.1) - case 'onehot': - # Forcing float32 on AMP - self.dist = lambda x: td.OneHotCategoricalStraightThrough(logits=x.float()) - case 'normal_tanh': - def get_tanh_normal(x, min_std=0.1): - mean, std = x.chunk(2, dim=-1) - init_std = np.log(np.exp(5) - 1) - raise NotImplementedError() - # return TanhNormal(torch.clamp(mean, -9.0, 9.0).float(), (F.softplus(std + init_std) + min_std).float(), upscale=5) - self.dist = get_tanh_normal - case 'normal_trunc': - def get_trunc_normal(x, min_std=0.1): - mean, std = x.chunk(2, dim=-1) - return TruncatedNormal(loc=torch.tanh(mean).float(), scale=(2*torch.sigmoid(std/2) + min_std).float(), a=-1, b=1) - self.dist = get_trunc_normal - case 'binary': - self.dist = lambda x: td.Bernoulli(logits=x) - case _: - raise RuntimeError("Invalid dist layer") - - def forward(self, x): - match self._dist: - case 'onehot': - return self.dist(x) - case _: - return td.Independent(self.dist(x), 1) - def Dist(val): return DistLayer('onehot')(val) @dataclass class State: - determ: Float[torch.Tensor, 'seq batch determ'] - stoch_logits: Float[torch.Tensor, 'seq batch latent_classes latent_dim'] - stoch_: t.Optional[Bool[torch.Tensor, 'seq batch stoch_dim']] = None + determ: Float[torch.Tensor, 'seq batch num_slots determ'] + stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] + stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None @property def combined(self): + return torch.concat([self.determ, self.stoch], dim=-1).flatten(2, 3) + + @property + def combined_slots(self): return torch.concat([self.determ, self.stoch], dim=-1) @property def stoch(self): if self.stoch_ is None: - self.stoch_ = Dist(self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:2] + (-1,)) + self.stoch_ = Dist(self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:3] + (-1,)) return self.stoch_ @property @@ -279,6 +211,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret # For observation we do not have ensemble # FIXME: very bad magic number img_sz = 4 * 384 # 384*2x2 + # img_sz = 192 self.stoch_net = nn.Sequential( # nn.LayerNorm(hidden_size + img_sz, hidden_size), nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' @@ -303,24 +236,26 @@ def estimate_stochastic_latent(self, prev_determ: torch.Tensor): def predict_next(self, prev_state: State, action) -> State: - x = self.pre_determ_recurrent(torch.concat([prev_state.stoch, action], dim=-1)) + x = self.pre_determ_recurrent(torch.concat([prev_state.stoch, action.unsqueeze(2).repeat((1, 1, 
prev_state.determ.shape[2], 1))], dim=-1)) # NOTE: x and determ are actually the same value if sequence of 1 is inserted - x, determ_prior = self.determ_recurrent(x, prev_state.determ) + x, determ_prior = self.determ_recurrent(x.flatten(1, 2), prev_state.determ.flatten(1, 2)) if self.discrete_rssm: - determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) - determ_post = determ_post.reshape(determ_prior.shape) - determ_post = self.determ_layer_norm(determ_post) - alpha = self.discretizer_scheduler.val - determ_post = alpha * determ_prior + (1-alpha) * determ_post + raise NotImplementedError("discrete rssm was not adopted for slot attention") + # determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) + # determ_post = determ_post.reshape(determ_prior.shape) + # determ_post = self.determ_layer_norm(determ_post) + # alpha = self.discretizer_scheduler.val + # determ_post = alpha * determ_prior + (1-alpha) * determ_post else: determ_post, diff = determ_prior, 0 # used for KL divergence predicted_stoch_logits = self.estimate_stochastic_latent(x) - return State(determ_post, predicted_stoch_logits), diff + # Size is 1 x B x slots_num x ... + return State(determ_post.reshape(prev_state.determ.shape), predicted_stoch_logits.reshape(prev_state.stoch_logits.shape)), diff def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' - return State(prior.determ, self.stoch_net(torch.concat([prior.determ, embed], dim=-1))) + return State(prior.determ, self.stoch_net(torch.concat([prior.determ, embed], dim=-1)).flatten(1, 2).reshape(prior.stoch_logits.shape)) def forward(self, h_prev: State, embed, action) -> tuple[State, State]: @@ -338,7 +273,7 @@ def forward(self, h_prev: State, embed, class Encoder(nn.Module): - def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[4, 4, 4, 4]): + def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[4, 4, 4]): super().__init__() layers = [] @@ -350,7 +285,7 @@ def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[4, 4, 4 layers.append(norm_layer(1, out_channels)) layers.append(nn.ELU(inplace=True)) in_channels = out_channels - layers.append(nn.Flatten()) + # layers.append(nn.Flatten()) self.net = nn.Sequential(*layers) def forward(self, X): @@ -371,7 +306,7 @@ def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_si out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step if i == len(kernel_sizes) - 1: out_channels = 3 - layers.append(nn.ConvTranspose2d(in_channels, 3, kernel_size=k, stride=2)) + layers.append(nn.ConvTranspose2d(in_channels, 4, kernel_size=k, stride=2)) else: layers.append(norm_layer(1, in_channels)) layers.append( @@ -384,7 +319,8 @@ def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_si def forward(self, X): x = self.convin(X) x = x.view(-1, 32 * self.channel_step, 1, 1) - return td.Independent(td.Normal(self.net(x), 1.0), 3) + return self.net(x) + # return td.Independent(td.Normal(self.net(x), 1.0), 3) class ViTDecoder(nn.Module): @@ -439,13 +375,15 @@ class WorldModel(nn.Module): def __init__(self, img_size, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, - predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float): + predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float, + slots_num: int): super().__init__() 
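# Rough shape bookkeeping for the slotted State dataclass above (values illustrative):
# State.determ is (seq, batch, slots_num, rssm_dim) and State.stoch is
# (seq, batch, slots_num, latent_dim * latent_classes); `combined` concatenates them on the
# last axis and flattens the slot axis, which is why the reward and discount heads below take
# slots_num * (rssm_dim + latent_dim * latent_classes) input features, e.g.
# 6 * (200 + 32 * 32) = 7344 with rssm_dim=200, latent_dim=latent_classes=32 and 6 slots.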
self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) self.kl_beta = kl_loss_scale self.rssm_dim = rssm_dim self.latent_dim = latent_dim self.latent_classes = latent_classes + self.slots_num = slots_num self.cluster_size = batch_cluster_size self.actions_num = actions_num # kl loss balancing (prior/posterior) @@ -483,6 +421,17 @@ def __init__(self, img_size, batch_cluster_size, latent_dim, latent_classes, rss else: self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm) + self.n_dim = 192 + self.slot_attention = SlotAttention(slots_num, 1536, 3) + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) + # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) + + self.slot_mlp = nn.Sequential( + nn.Linear(192, 768), + nn.ReLU(inplace=True), + nn.Linear(768, 1536) + ) + if decode_vit: self.dino_predictor = ViTDecoder(rssm_dim + latent_dim * latent_classes, @@ -497,14 +446,14 @@ def __init__(self, img_size, batch_cluster_size, latent_dim, latent_classes, rss self.image_predictor = Decoder(rssm_dim + latent_dim * latent_classes, norm_layer=nn.Identity if layer_norm else nn.GroupNorm) - self.reward_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, + self.reward_predictor = fc_nn_generator(slots_num*(rssm_dim + latent_dim * latent_classes), 1, hidden_size=400, num_layers=5, intermediate_activation=nn.ELU, layer_norm=layer_norm, final_activation=DistLayer('mse')) - self.discount_predictor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, + self.discount_predictor = fc_nn_generator(slots_num*(rssm_dim + latent_dim * latent_classes), 1, hidden_size=400, num_layers=5, @@ -515,9 +464,9 @@ def __init__(self, img_size, batch_cluster_size, latent_dim, latent_classes, rss def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): device = next(self.parameters()).device - return State(torch.zeros(seq_size, batch_size, self.rssm_dim, device=device), - torch.zeros(seq_size, batch_size, self.latent_classes, self.latent_dim, device=device), - torch.zeros(seq_size, batch_size, self.latent_classes * self.latent_dim, device=device)) + return State(torch.zeros(seq_size, batch_size, self.slots_num, self.rssm_dim, device=device), + torch.zeros(seq_size, batch_size, self.slots_num, self.latent_classes, self.latent_dim, device=device), + torch.zeros(seq_size, batch_size, self.slots_num, self.latent_classes * self.latent_dim, device=device)) def predict_next(self, prev_state: State, action): prior, _ = self.recurrent_model.predict_next(prev_state, action) @@ -529,20 +478,29 @@ def predict_next(self, prev_state: State, action): discount_factors = torch.ones_like(reward) return prior, reward, discount_factors - def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State]) -> State: + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State], prev_slots: t.Optional[torch.Tensor]) -> t.Tuple[State, torch.Tensor]: if state is None: state = self.get_initial_state() embed = self.encoder(obs.unsqueeze(0)) - _, posterior, _ = self.recurrent_model.forward(state, embed.unsqueeze(0), - action) - return posterior + embed_with_pos_enc = self.positional_augmenter_inp(embed) + + pre_slot_features_t = self.slot_mlp(embed_with_pos_enc.permute(0, 2, 3, 1).reshape(1, -1, self.n_dim)) + + slots_t = self.slot_attention(pre_slot_features_t, prev_slots) + + _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), action) + return posterior, None def calculate_loss(self, obs: torch.Tensor, a: 
torch.Tensor, r: torch.Tensor, discount: torch.Tensor, first: torch.Tensor): b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) - embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + # embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) + + pre_slot_features = self.slot_mlp(embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) + pre_slot_features_c = pre_slot_features.reshape(b // self.cluster_size, self.cluster_size, -1, 1536) a_c = a.reshape(-1, self.cluster_size, self.actions_num) r_c = r.reshape(-1, self.cluster_size, 1) @@ -575,12 +533,16 @@ def KL(dist1, dist2): d_features = self.dino_vit(inp) prev_state = self.get_initial_state(b // self.cluster_size) + prev_slots = None for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 - embed_t, a_t, first_t = embed_c[:, t].unsqueeze(0), a_c[:, t].unsqueeze(0), first_c[:, t].unsqueeze(0) + pre_slot_feature_t, a_t, first_t = pre_slot_features_c[:, t], a_c[:, t].unsqueeze(0), first_c[:, t].unsqueeze(0) a_t = a_t * (1 - first_t) - prior, posterior, diff = self.recurrent_model.forward(prev_state, embed_t, a_t) + slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) + prev_slots = None + + prior, posterior, diff = self.recurrent_model.forward(prev_state, slots_t.unsqueeze(0), a_t) prev_state = posterior priors.append(prior) @@ -597,19 +559,24 @@ def KL(dist1, dist2): losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) if not self.decode_vit: - x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() else: - if self.vit_l2_ratio != 1.0: - x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) - img_rec = -x_r.log_prob(obs).float().mean() - else: - img_rec = 0 - x_r_detached = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1).detach()) - losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() - d_pred = self.dino_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) - losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 14, 14)).float().mean()/4 + - (1-self.vit_l2_ratio) * img_rec) + raise NotImplementedError("") + # if self.vit_l2_ratio != 1.0: + # x_r = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) + # img_rec = -x_r.log_prob(obs).float().mean() + # else: + # img_rec = 0 + # x_r_detached = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1).detach()) + # losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() + # d_pred = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) + # losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 14, 14)).float().mean()/4 + + # (1-self.vit_l2_ratio) * img_rec) prior_logits = prior.stoch_logits posterior_logits = posterior.stoch_logits @@ -714,6 +681,7 @@ def __init__( encode_vit: bool, decode_vit: bool, vit_l2_ratio: float, + slots_num: int, device_type: str = 'cpu', logger = None): @@ -733,9 +701,9 @@ def 
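The image head above now emits one RGBA map per slot and composites them with a softmax over the slot axis. A minimal sketch of that compositing step, using placeholder sizes and a random stand-in tensor (`decoder_out`) for the `image_predictor` output:

```python
import torch
import torch.nn.functional as F

# Placeholder batch/slot/image sizes; `decoder_out` stands in for the image_predictor output.
b, slots_num, h, w = 2, 8, 64, 64
decoder_out = torch.randn(b * slots_num, 4, h, w)

per_slot = decoder_out.reshape(b, slots_num, 4, h, w)
rgb, mask_logits = per_slot.split([3, 1], dim=2)      # [b, S, 3, h, w] and [b, S, 1, h, w]
mask = F.softmax(mask_logits, dim=1)                  # slots compete for every pixel
recon = (rgb * mask).sum(dim=1)                       # [b, 3, h, w] composited reconstruction

assert torch.allclose(mask.sum(dim=1), torch.ones(b, 1, h, w))
```

The reconstruction loss is then the log-probability of the composite under a unit-variance Normal, matching the `td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3)` term in the hunk above.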
__init__( kl_loss_balancing, kl_loss_free_nats, discrete_rssm, world_model_predict_discount, layer_norm, - encode_vit, decode_vit, vit_l2_ratio).to(device_type) + encode_vit, decode_vit, vit_l2_ratio, slots_num).to(device_type) - self.actor = fc_nn_generator(rssm_dim + latent_dim * latent_classes, + self.actor = fc_nn_generator(slots_num*(rssm_dim + latent_dim * latent_classes), actions_num if self.is_discrete else actions_num * 2, 400, 5, @@ -746,7 +714,7 @@ def __init__( self.critic = ImaginativeCritic(discount_factor, critic_update_interval, critic_soft_update_fraction, critic_value_target_lambda, - rssm_dim + latent_dim * latent_classes, + slots_num*(rssm_dim + latent_dim * latent_classes), layer_norm=layer_norm).to(device_type) self.scaler = torch.cuda.amp.GradScaler() @@ -754,10 +722,20 @@ def __init__( lr=world_model_lr, eps=1e-5, weight_decay=1e-6) + self.world_model_optimizer = torch.optim.AdamW(self.world_model.parameters(), lr=world_model_lr, eps=1e-5, weight_decay=1e-6) + + warmup_steps = 1e4 + decay_rate = 0.5 + decay_steps = 5e5 + lr_warmup_scheduler = torch.optim.lr_scheduler.LinearLR(self.world_model_optimizer, start_factor=1/warmup_steps, total_iters=int(warmup_steps)) + lr_decay_scheduler = torch.optim.lr_scheduler.LambdaLR(self.world_model_optimizer, lambda epoch: decay_rate**(epoch/decay_steps)) + # lr_decay_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate**(1/decay_steps)) + self.lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler([lr_warmup_scheduler, lr_decay_scheduler]) + self.actor_optimizer = torch.optim.AdamW(self.actor.parameters(), lr=actor_lr, eps=1e-5, @@ -800,8 +778,9 @@ def imagine_trajectory( def reset(self): self._state = self.world_model.get_initial_state() + self._prev_slots = None self._last_action = torch.zeros((1, 1, self.actions_num), device=self.device) - self._latent_probs = torch.zeros((32, 32), device=self.device) + self._latent_probs = torch.zeros((self.world_model.latent_classes, self.world_model.latent_dim), device=self.device) self._action_probs = torch.zeros((self.actions_num), device=self.device) self._stored_steps = 0 @@ -820,17 +799,18 @@ def preprocess_obs(self, obs: torch.Tensor): def get_action(self, obs: Observation) -> Action: # NOTE: pytorch fails without .copy() only when get_action is called + # FIXME: return back action selection obs = torch.from_numpy(obs.copy()).to(self.device) obs = self.preprocess_obs(obs) - self._state = self.world_model.get_latent(obs, self._last_action, self._state) + self._state, self._prev_slots = self.world_model.get_latent(obs, self._last_action, self._state, self._prev_slots) actor_dist = self.actor(self._state.combined) self._last_action = actor_dist.sample() if self.is_discrete: - self._action_probs += actor_dist.probs.squeeze() - self._latent_probs += self._state.stoch_dist.probs.squeeze() + self._action_probs += actor_dist.probs.squeeze().mean(dim=0) + self._latent_probs += self._state.stoch_dist.probs.squeeze().mean(dim=0) self._stored_steps += 1 if self.is_discrete: @@ -845,9 +825,11 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ if self.is_discrete: actions = F.one_hot(actions.to(torch.int64), num_classes=self.actions_num).squeeze() video = [] + slots_video = [] rews = [] state = None + prev_slots = None means = np.array([0.485, 0.456, 0.406]) stds = np.array([0.229, 0.224, 0.225]) UnNormalize = tv.transforms.Normalize(list(-means/stds), @@ -855,28 +837,43 @@ def _generate_video(self, obs: list[Observation], actions: 
list[Action], update_ for idx, (o, a) in enumerate(list(zip(obs, actions))): if idx > update_num: break - state = self.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), state) - video_r = self.world_model.image_predictor(state.combined).mode.cpu().detach().numpy() + state, prev_slots = self.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), state, prev_slots) + # video_r = self.world_model.image_predictor(state.combined_slots).mode.cpu().detach().numpy() + + decoded_imgs, masks = self.world_model.image_predictor(state.combined_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) + # TODO: try the scaling of softmax as in attention + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1).cpu().detach().numpy() + rews.append(self.world_model.reward_predictor(state.combined).mode.item()) if self.world_model.encode_vit: video_r = UnNormalize(torch.from_numpy(video_r)).numpy() else: video_r = (video_r + 0.5) video.append(video_r) + slots_video.append(decoded_imgs.cpu().detach().numpy() + 0.5) rews = torch.Tensor(rews).to(obs.device) if update_num < len(obs): states, _, rews_2, _ = self.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.imagination_horizon - 1 - update_num) rews = torch.cat([rews, rews_2[1:].squeeze()]) - video_r = self.world_model.image_predictor(states.combined[1:]).mode.cpu().detach().numpy() + + # video_r = self.world_model.image_predictor(states.combined_slots[1:]).mode.cpu().detach().numpy() + decoded_imgs, masks = self.world_model.image_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1).cpu().detach().numpy() + if self.world_model.encode_vit: video_r = UnNormalize(torch.from_numpy(video_r)).numpy() else: video_r = (video_r + 0.5) video.append(video_r) + slots_video.append(decoded_imgs.cpu().detach().numpy() + 0.5) - return np.concatenate(video), rews + return np.concatenate(video), rews, np.concatenate(slots_video) def viz_log(self, rollout, logger, epoch_num): init_indeces = np.random.choice(len(rollout.states) - self.imagination_horizon, 5) @@ -888,12 +885,16 @@ def viz_log(self, rollout, logger, epoch_num): real_rewards = [rollout.rewards[idx:idx+ self.imagination_horizon] for idx in init_indeces] - videos_r, imagined_rewards = zip(*[self._generate_video(obs_0.copy(), a_0, update_num=self.imagination_horizon//3) for obs_0, a_0 in zip( + videos_r, imagined_rewards, slots_video = zip(*[self._generate_video(obs_0.copy(), a_0, update_num=self.imagination_horizon//3) for obs_0, a_0 in zip( [rollout.next_states[idx:idx+ self.imagination_horizon] for idx in init_indeces], [rollout.actions[idx:idx+ self.imagination_horizon] for idx in init_indeces]) ]) videos_r = np.concatenate(videos_r, axis=3) + slots_video = np.concatenate(list(slots_video)[:3], axis=3) + slots_video = slots_video.transpose((0, 2, 3, 1, 4)) + slots_video = np.expand_dims(slots_video.reshape(*slots_video.shape[:-2], -1), 0) + videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r + 1)/2], axis=2), 0) videos_comparison = (videos_comparison * 255.0).astype(np.uint8) latent_hist = (self._latent_probs / self._stored_steps).detach().cpu().numpy() @@ -909,9 +910,10 @@ def viz_log(self, rollout, logger, epoch_num): else: # log mean +- std pass - 
logger.add_image('val/latent_probs', latent_hist, epoch_num) - logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=1), epoch_num) + logger.add_image('val/latent_probs', latent_hist, epoch_num, dataformats='HW') + logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=1), epoch_num, dataformats='HW') logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) + logger.add_video('val/dreamed_slots', slots_video, epoch_num) rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) @@ -958,10 +960,13 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # FIXME: clip gradient should be parametrized self.scaler.unscale_(self.world_model_optimizer) + # for tag, value in self.world_model.named_parameters(): + # wm_metrics[f"grad/{tag.replace('.', '/')}"] = value.detach() nn.utils.clip_grad_norm_(self.world_model.parameters(), 100) self.scaler.step(self.world_model_optimizer) + self.lr_scheduler.step() - metrics = {} + metrics = wm_metrics with torch.cuda.amp.autocast(enabled=True): losses_ac = {} diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index a36ca9b..07967c6 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -1,13 +1,14 @@ _target_: rl_sandbox.agents.DreamerV2 -layer_norm: false +layer_norm: true # World model parameters batch_cluster_size: 50 latent_dim: 32 latent_classes: 32 rssm_dim: 200 +slots_num: 1 kl_loss_scale: 1.0 kl_loss_balancing: 0.8 -kl_loss_free_nats: 1.0 +kl_loss_free_nats: 0.0 world_model_lr: 3e-4 world_model_predict_discount: false @@ -15,13 +16,13 @@ world_model_predict_discount: false discount_factor: 0.999 imagination_horizon: 15 -actor_lr: 3e-4 +actor_lr: 8e-5 # mixing of reinforce and maximizing value func # for dm_control it is zero in Dreamer (Atari 1) actor_reinforce_fraction: null actor_entropy_scale: 1e-4 -critic_lr: 3e-4 +critic_lr: 8e-5 # Lambda parameter for trainin deeper multi-step prediction critic_value_target_lambda: 0.95 critic_update_interval: 100 @@ -29,6 +30,6 @@ critic_update_interval: 100 critic_soft_update_fraction: 1 discrete_rssm: false -decode_vit: true +decode_vit: false vit_l2_ratio: 1.0 encode_vit: false diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index c029dd1..84cd395 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,6 +1,6 @@ defaults: - agent: dreamer_v2 - - env: dm_acrobot + - env: dm_cartpole - training: dm - _self_ @@ -8,20 +8,26 @@ seed: 42 device_type: cuda logger: - type: tensorboard - message: Acrobot reference with 1 nat + type: null + message: Cartpole with only 1 slot without KL + #message: test_last log_grads: false training: checkpoint_path: null - steps: 1.5e6 - batch_size: 50 - val_logs_every: 2e4 + steps: 1e6 + #prefill: 0 + pretrain: 0 + batch_size: 16 + val_logs_every: 2.5e3 validation: rollout_num: 5 visualize: true +#agents: +# batch_cluster_size: 10 + debug: profiler: false diff --git a/rl_sandbox/config/training/dm.yaml b/rl_sandbox/config/training/dm.yaml index fc2bfc1..fd32015 100644 --- a/rl_sandbox/config/training/dm.yaml +++ b/rl_sandbox/config/training/dm.yaml @@ -4,5 +4,5 @@ batch_size: 16 pretrain: 100 prioritize_ends: false train_every: 5 -save_checkpoint_every: 2e5 +save_checkpoint_every: 2e6 val_logs_every: 2e4 diff --git 
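The `self.lr_scheduler.step()` call added to `train()` above drives the warmup-plus-decay schedule attached to the world-model optimizer. A standalone sketch of that schedule with the constants used above (1e4 warmup steps, factor 0.5 decay over 5e5 steps); the linear layer and optimizer are placeholders:

```python
import torch
from torch import nn

model = nn.Linear(8, 8)                                   # placeholder for the world model
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, eps=1e-5, weight_decay=1e-6)

warmup_steps, decay_rate, decay_steps = 1e4, 0.5, 5e5
warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1 / warmup_steps,
                                           total_iters=int(warmup_steps))
decay = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: decay_rate ** (step / decay_steps))
scheduler = torch.optim.lr_scheduler.ChainedScheduler([warmup, decay])

for step in range(3):
    opt.step()                                            # one (dummy) optimizer step
    scheduler.step()                                      # both chained members advance every step
    print(step, opt.param_groups[0]['lr'])
```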
a/rl_sandbox/utils/dists.py b/rl_sandbox/utils/dists.py index 57214fa..8f7032d 100644 --- a/rl_sandbox/utils/dists.py +++ b/rl_sandbox/utils/dists.py @@ -1,8 +1,12 @@ # Taken from https://raw.githubusercontent.com/toshas/torch_truncnorm/main/TruncatedNormal.py +# Added torch modules on top import math from numbers import Number +import typing as t import torch +import torch.distributions as td +from torch import nn from torch.distributions import Distribution, constraints from torch.distributions.utils import broadcast_all @@ -134,3 +138,78 @@ def icdf(self, value): def log_prob(self, value): return super(TruncatedNormal, self).log_prob(self._to_std_rv(value)) - self._log_scale + + +class Sigmoid2(nn.Module): + def forward(self, x): + return 2*torch.sigmoid(x/2) + +class NormalWithOffset(nn.Module): + def __init__(self, min_std: float, std_trans: str = 'sigmoid2', transform: t.Optional[str] = None): + super().__init__() + self.min_std = min_std + match std_trans: + case 'identity': + self.std_trans = nn.Identity() + case 'softplus': + self.std_trans = nn.Softplus() + case 'sigmoid': + self.std_trans = nn.Sigmoid() + case 'sigmoid2': + self.std_trans = Sigmoid2() + case _: + raise RuntimeError("Unknown std transformation") + + match transform: + case 'tanh': + self.trans = [td.TanhTransform(cache_size=1)] + case None: + self.trans = None + case _: + raise RuntimeError("Unknown distribution transformation") + + def forward(self, x): + mean, std = x.chunk(2, dim=-1) + dist = td.Normal(mean, self.std_trans(std) + self.min_std) + if self.trans is None: + return dist + else: + return td.TransformedDistribution(dist, self.trans) + +class DistLayer(nn.Module): + def __init__(self, type: str): + super().__init__() + self._dist = type + match type: + case 'mse': + self.dist = lambda x: td.Normal(x.float(), 1.0) + case 'normal': + self.dist = NormalWithOffset(min_std=0.1) + case 'onehot': + # Forcing float32 on AMP + self.dist = lambda x: td.OneHotCategoricalStraightThrough(logits=x.float()) + case 'normal_tanh': + def get_tanh_normal(x, min_std=0.1): + mean, std = x.chunk(2, dim=-1) + init_std = np.log(np.exp(5) - 1) + raise NotImplementedError() + # return TanhNormal(torch.clamp(mean, -9.0, 9.0).float(), (F.softplus(std + init_std) + min_std).float(), upscale=5) + self.dist = get_tanh_normal + case 'normal_trunc': + def get_trunc_normal(x, min_std=0.1): + mean, std = x.chunk(2, dim=-1) + return TruncatedNormal(loc=torch.tanh(mean).float(), scale=(2*torch.sigmoid(std/2) + min_std).float(), a=-1, b=1) + self.dist = get_trunc_normal + case 'binary': + self.dist = lambda x: td.Bernoulli(logits=x) + case _: + raise RuntimeError("Invalid dist layer") + + def forward(self, x): + match self._dist: + case 'onehot': + return self.dist(x) + case _: + # FIXME: verify dimensionality of independent + return td.Independent(self.dist(x), 1) + diff --git a/rl_sandbox/utils/logger.py b/rl_sandbox/utils/logger.py index c3981d8..09f6f1c 100644 --- a/rl_sandbox/utils/logger.py +++ b/rl_sandbox/utils/logger.py @@ -50,8 +50,8 @@ def log(self, losses: dict[str, t.Any], global_step: int, mode: str = 'train'): def add_scalar(self, name: str, value: t.Any, global_step: int): self.writer.add_scalar(name, value, global_step) - def add_image(self, name: str, image: t.Any, global_step: int): - self.writer.add_image(name, image, global_step) + def add_image(self, name: str, image: t.Any, global_step: int, dataformats: str = 'CHW'): + self.writer.add_image(name, image, global_step, dataformats=dataformats) def add_video(self, 
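`DistLayer` above wraps each output head in a distribution, so every predictor is trained by maximising log-likelihood. A small usage sketch for the `'mse'` head, assuming the class is importable from `rl_sandbox.utils.dists` as in the patch; the negative log-probability reduces to a squared error plus a constant:

```python
import math
import torch
from rl_sandbox.utils.dists import DistLayer

head = DistLayer('mse')
pred = torch.randn(4, 3)              # placeholder network output
target = torch.randn(4, 3)

dist = head(pred)                     # td.Independent(td.Normal(pred, 1.0), 1)
nll = -dist.log_prob(target)          # one value per batch row
mse = 0.5 * ((pred - target) ** 2).sum(dim=-1)
const = 0.5 * 3 * math.log(2 * math.pi)
assert torch.allclose(nll, mse + const)
```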
name: str, video: t.Any, global_step: int): self.writer.add_video(name, video, global_step, fps=20) diff --git a/rl_sandbox/vision/dino.py b/rl_sandbox/vision/dino.py index 5952379..87f36c1 100644 --- a/rl_sandbox/vision/dino.py +++ b/rl_sandbox/vision/dino.py @@ -131,7 +131,7 @@ def forward(self, x): x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) - return x, qkv, attn + return x, attn class Block(nn.Module): @@ -147,9 +147,9 @@ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, x, return_attention=False): - y, qkv, attn = self.attn(self.norm1(x)) + y, attn = self.attn(self.norm1(x)) if return_attention: - return qkv, attn + return attn x = x + self.drop_path(y) x = x + self.drop_path(self.mlp(self.norm2(x))) return x @@ -319,17 +319,17 @@ def forward(self, img) : def hook_fn_forward_qkv(module, input, output): feat_out["qkv"] = output - # self.model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) + self.model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) # Forward pass in the model with torch.no_grad() : h, w = img.shape[2], img.shape[3] feat_h, feat_w = h // self.patch_size, w // self.patch_size - qkv, attentions = self.model.get_last_selfattention(img) + attentions = self.model.get_last_selfattention(img) bs, nb_head, nb_token = attentions.shape[0], attentions.shape[1], attentions.shape[2] qkv = ( - qkv + feat_out["qkv"] .reshape(bs, nb_token, 3, nb_head, -1) .permute(2, 0, 3, 1, 4) ) diff --git a/rl_sandbox/vision/my_slot_attention.py b/rl_sandbox/vision/my_slot_attention.py index b83bc11..f0fdd81 100644 --- a/rl_sandbox/vision/my_slot_attention.py +++ b/rl_sandbox/vision/my_slot_attention.py @@ -112,13 +112,13 @@ def __init__(self, num_slots: int, n_iter: int, dino_inp_size: int = 224): self.positional_augmenter_vit_dec = PositionalEmbedding(self.n_dim, (self.lat_dim, self.lat_dim)) self.slot_attention = SlotAttention(num_slots, self.n_dim, n_iter) self.img_decoder = nn.Sequential( # Dx8x8 -> (3+1)x64x64 - nn.ConvTranspose2d(self.n_dim, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ConvTranspose2d(self.n_dim, 48, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ConvTranspose2d(48, 96, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ConvTranspose2d(96, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ConvTranspose2d(192, 4, kernel_size=3, stride=(1, 1), padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1), ) self.vit_decoder = nn.Sequential( # Dx1x1 -> (384+1)x14x14 @@ -170,7 +170,7 @@ def forward(self, X: Float[torch.Tensor, 'batch 3 h w'], prev_slots: t.Optional[ slots_grid = slots_grid.repeat((1, 1, 8, 8)) # -> batch*num_slots D sqrt(D) sqrt(D) slots_with_pos_enc = self.positional_augmenter_dec(slots_grid) - decoded_imgs, masks = self.img_decoder(slots_with_pos_enc).permute(0, 2, 3, 1).reshape(batch, -1, *(np.array(X.shape[2:])//2), 4).split([3, 1], dim=-1) + decoded_imgs, masks = self.img_decoder(slots_with_pos_enc).permute(0, 2, 3, 1).reshape(batch, -1, 
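The `dino.py` change above goes back to reading the qkv projection through a registered forward hook instead of returning it from `forward`. A generic sketch of that hook pattern on a placeholder module (the real code attaches it to `blocks[-1].attn.qkv`):

```python
import torch
from torch import nn

feat_out = {}

def hook_fn_forward_qkv(module, inputs, output):
    feat_out["qkv"] = output          # stash the submodule output during the forward pass

net = nn.Sequential(nn.Linear(16, 48), nn.ReLU(), nn.Linear(48, 8))   # stand-in model
handle = net[0].register_forward_hook(hook_fn_forward_qkv)

with torch.no_grad():
    net(torch.randn(2, 16))

print(feat_out["qkv"].shape)          # torch.Size([2, 48])
handle.remove()                       # detach the hook once the features are captured
```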
*np.array(X.shape[2:]), 4).split([3, 1], dim=-1) img_mask = F.softmax(masks, dim=1) decoded_imgs = decoded_imgs * img_mask @@ -185,7 +185,7 @@ def forward(self, X: Float[torch.Tensor, 'batch 3 h w'], prev_slots: t.Optional[ if __name__ == '__main__': device = 'cuda' - debug = True + debug = False ToTensor = tv.transforms.Compose([tv.transforms.ToTensor(), tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) @@ -235,7 +235,7 @@ def forward(self, X: Float[torch.Tensor, 'batch 3 h w'], prev_slots: t.Optional[ res = model(img.to(device)) recovered_img, vit_rec_loss = res['rec_img'], res['vit_rec_loss'] - reg_loss = F.mse_loss(img.to(device)[:, :, ::2, ::2], recovered_img) + reg_loss = F.mse_loss(img.to(device), recovered_img) lambda_ = 0.1 loss = lambda_ * reg_loss + (1 - lambda_) * vit_rec_loss From d99db5f219f0d8210b582b92ccec86f46b1d5fe4 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Sat, 15 Apr 2023 10:19:45 +0000 Subject: [PATCH 059/106] Renamed --- .vimspector.json | 2 +- rl_sandbox/agents/dreamer_v2.py | 2 +- rl_sandbox/vision/my_slot_attention.py | 278 ----------------------- rl_sandbox/vision/slot_attention.py | 295 +++++++++++++++++++++---- 4 files changed, 254 insertions(+), 323 deletions(-) delete mode 100644 rl_sandbox/vision/my_slot_attention.py diff --git a/.vimspector.json b/.vimspector.json index 3945f69..44116be 100644 --- a/.vimspector.json +++ b/.vimspector.json @@ -46,7 +46,7 @@ "Run dino": { "extends": "python-base", "configuration": { - "program": "rl_sandbox/vision/my_slot_attention.py", + "program": "rl_sandbox/vision/slot_attention.py", "args": [] } } diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 64933fc..f48ba81 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -21,7 +21,7 @@ TerminationFlags, IsFirstFlags) from rl_sandbox.utils.schedulers import LinearScheduler from rl_sandbox.utils.dists import DistLayer -from rl_sandbox.vision.my_slot_attention import SlotAttention, PositionalEmbedding +from rl_sandbox.vision.slot_attention import SlotAttention, PositionalEmbedding class View(nn.Module): diff --git a/rl_sandbox/vision/my_slot_attention.py b/rl_sandbox/vision/my_slot_attention.py deleted file mode 100644 index f0fdd81..0000000 --- a/rl_sandbox/vision/my_slot_attention.py +++ /dev/null @@ -1,278 +0,0 @@ -import torch -import typing as t -from torch import nn -import torch.nn.functional as F -from jaxtyping import Float -import torchvision as tv -from tqdm import tqdm -import numpy as np - -from rl_sandbox.vision.dino import ViTFeat -from rl_sandbox.utils.logger import Logger - -class SlotAttention(nn.Module): - def __init__(self, num_slots: int, n_dim: int, n_iter: int): - super().__init__() - - self.n_slots = num_slots - self.n_iter = n_iter - self.n_dim = n_dim - self.scale = self.n_dim**(-1/2) - self.epsilon = 1e-8 - - self.slots_mu = nn.Parameter(torch.randn(1, 1, self.n_dim)) - self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.n_dim)) - nn.init.xavier_uniform_(self.slots_logsigma) - - self.slots_proj = nn.Linear(n_dim, n_dim) - self.slots_proj_2 = nn.Sequential( - nn.Linear(n_dim, n_dim*4), - nn.ReLU(inplace=True), - nn.Linear(n_dim*4, n_dim), - ) - self.slots_norm = nn.LayerNorm(self.n_dim) - self.slots_norm_2 = nn.LayerNorm(self.n_dim) - self.slots_reccur = nn.GRUCell(input_size=self.n_dim, hidden_size=self.n_dim) - - self.inputs_proj = nn.Linear(n_dim, n_dim*2) - self.inputs_norm = nn.LayerNorm(self.n_dim) - - def forward(self, X: 
Float[torch.Tensor, 'batch seq n_dim'], prev_slots: t.Optional[Float[torch.Tensor, 'batch num_slots n_dim']]) -> Float[torch.Tensor, 'batch num_slots n_dim']: - batch, _, _ = X.shape - k, v = self.inputs_proj(self.inputs_norm(X)).chunk(2, dim=-1) - - if prev_slots is None: - slots = self.slots_mu + self.slots_logsigma.exp() * torch.randn(batch, self.n_slots, self.n_dim, device=X.device) - else: - slots = prev_slots - - for _ in range(self.n_iter): - slots_prev = slots - slots = self.slots_norm(slots) - q = self.slots_proj(slots) - - attn = F.softmax(self.scale*torch.einsum('bik,bjk->bij', q, k), dim=1) + self.epsilon - attn = attn / attn.sum(dim=-1, keepdim=True) - - updates = torch.einsum('bij,bjk->bik', attn, v) / self.n_slots - slots = self.slots_reccur(updates.reshape(-1, self.n_dim), slots_prev.reshape(-1, self.n_dim)).reshape(batch, self.n_slots, self.n_dim) - slots = slots + self.slots_proj_2(self.slots_norm_2(slots)) - return slots - -def build_grid(resolution): - ranges = [np.linspace(0., 1., num=res) for res in resolution] - grid = np.meshgrid(*ranges, sparse=False, indexing="ij") - grid = np.stack(grid, axis=-1) - grid = np.reshape(grid, [resolution[0], resolution[1], -1]) - grid = np.expand_dims(grid, axis=0) - grid = grid.astype(np.float32) - return np.concatenate([grid, 1.0 - grid], axis=-1) - - -class PositionalEmbedding(nn.Module): - def __init__(self, n_dim: int, res: t.Tuple[int, int]): - super().__init__() - self.n_dim = n_dim - self.proj = nn.Linear(4, n_dim) - self.register_buffer('grid', torch.from_numpy(build_grid(res))) - - def forward(self, X) -> torch.Tensor: - return X + self.proj(self.grid).permute(0, 3, 1, 2) - -class SlottedAutoEncoder(nn.Module): - def __init__(self, num_slots: int, n_iter: int, dino_inp_size: int = 224): - super().__init__() - in_channels = 3 - self.n_dim = 128 - self.lat_dim = int(self.n_dim**0.5) - self.encoder = nn.Sequential( - nn.Conv2d(in_channels, 64, kernel_size=5, stride=2, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(64, self.n_dim, kernel_size=5, padding='same'), - nn.ReLU(inplace=True), - ) - - self.mlp = nn.Sequential( - nn.Linear(self.n_dim, self.n_dim), - nn.ReLU(inplace=True), - nn.Linear(self.n_dim, self.n_dim) - ) - - self.dino_inp_size = dino_inp_size - self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=16) - self.vit_patch_num = self.dino_inp_size // self.dino_vit.patch_size - self.vit_feat = self.dino_vit.feat_dim - - self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (7, 7)) - self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) - self.positional_augmenter_vit_dec = PositionalEmbedding(self.n_dim, (self.lat_dim, self.lat_dim)) - self.slot_attention = SlotAttention(num_slots, self.n_dim, n_iter) - self.img_decoder = nn.Sequential( # Dx8x8 -> (3+1)x64x64 - nn.ConvTranspose2d(self.n_dim, 48, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.ReLU(inplace=True), - nn.ConvTranspose2d(48, 96, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.ReLU(inplace=True), - nn.ConvTranspose2d(96, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.ConvTranspose2d(192, 4, kernel_size=3, stride=(1, 1), padding=1), - nn.ReLU(inplace=True), - ) - - self.vit_decoder = nn.Sequential( # Dx1x1 -> (384+1)x14x14 - 
nn.ConvTranspose2d(self.n_dim, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.ReLU(inplace=True), - nn.ConvTranspose2d(192, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.ReLU(inplace=True), - nn.ConvTranspose2d(192, self.vit_feat, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.ReLU(inplace=True), - nn.ConvTranspose2d(self.vit_feat, self.vit_feat*2, kernel_size=3, stride=(2, 2), padding=2, output_padding=1), - nn.ReLU(inplace=True), - nn.ConvTranspose2d(self.vit_feat*2, self.vit_feat+1, kernel_size=3, stride=(1, 1), padding=1), - ) - - # self.vit_decoder_mlp = nn.Sequential( - # nn.Linear(self.n_dim, 1024), - # nn.ReLU(inplace=True), - # nn.Linear(1024, 1024), - # nn.ReLU(inplace=True), - # nn.Linear(1024, 1024), - # nn.ReLU(inplace=True), - # nn.Linear(1024, self.vit_feat+1), - # nn.ReLU(inplace=True) - # ) - - def forward(self, X: Float[torch.Tensor, 'batch 3 h w'], prev_slots: t.Optional[Float[torch.Tensor, 'batch num_slots n_dim']] = None) -> t.Dict[str, torch.Tensor]: - features = self.encoder(X) # -> batch D h w - features_with_pos_enc = self.positional_augmenter_inp(features) # -> batch D h w - - resize = tv.transforms.Resize(self.dino_inp_size, antialias=True) - - batch, seq, _, _ = X.shape - vit_features = self.dino_vit(resize(X)) - vit_features = vit_features.reshape(batch, -1, self.vit_patch_num, self.vit_patch_num) - - pre_slot_features = self.mlp(features_with_pos_enc.permute(0, 2, 3, 1).reshape(batch, -1, self.n_dim)) - - slots = self.slot_attention(pre_slot_features, prev_slots) # -> batch num_slots D - slots_grid = slots.flatten(0, 1).reshape(-1, 1, 1, self.n_dim).permute(0, 3, 1, 2) - - # slots_with_vit_pos_enc = self.positional_augmenter_vit_dec(slots_grid.flatten(2, 3).repeat((1, 1, 196)).reshape(-1, self.n_dim, self.lat_dim, self.lat_dim)).flatten(2, 3) - # decoded_features, vit_masks =self.vit_decoder_mlp(slots_with_vit_pos_enc).reshape(batch, -1, self.vit_patch_num, self.vit_patch_num, self.vit_feat+1).split([self.vit_feat, 1], dim=-1) - - decoded_features, vit_masks = self.vit_decoder(slots_grid).permute(0, 2, 3, 1).reshape(batch, -1, self.vit_patch_num, self.vit_patch_num, self.vit_feat+1).split([self.vit_feat, 1], dim=-1) - vit_mask = F.softmax(vit_masks, dim=1) - - rec_features = (decoded_features * vit_mask).sum(dim=1) - - slots_grid = slots_grid.repeat((1, 1, 8, 8)) # -> batch*num_slots D sqrt(D) sqrt(D) - slots_with_pos_enc = self.positional_augmenter_dec(slots_grid) - - decoded_imgs, masks = self.img_decoder(slots_with_pos_enc).permute(0, 2, 3, 1).reshape(batch, -1, *np.array(X.shape[2:]), 4).split([3, 1], dim=-1) - img_mask = F.softmax(masks, dim=1) - - decoded_imgs = decoded_imgs * img_mask - rec_img = torch.sum(decoded_imgs, dim=1) - return { - 'rec_img': rec_img.permute(0, 3, 1, 2), - 'img_per_slot': decoded_imgs.permute(0, 1, 4, 2, 3), - 'vit_mask': vit_mask, - 'vit_rec_loss': F.mse_loss(rec_features.permute(0, 3, 1, 2), vit_features), - 'slots': slots - } - -if __name__ == '__main__': - device = 'cuda' - debug = False - ToTensor = tv.transforms.Compose([tv.transforms.ToTensor(), - tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - ]) - - train_data = tv.datasets.ImageFolder('~/rl_sandbox/crafter_data/', transform=ToTensor) - if debug: - train_data_loader = torch.utils.data.DataLoader(train_data, - batch_size=4, - prefetch_factor=1, - shuffle=False, - num_workers=2) - else: - train_data_loader = torch.utils.data.DataLoader(train_data, - batch_size=32, - shuffle=True, - 
num_workers=8) - - import socket - from datetime import datetime - current_time = datetime.now().strftime("%b%d_%H-%M-%S") - comment = "Added vit masks logging, lambda=0.1, return old dino".replace(" ", "_") - logger = Logger(None if debug else 'tensorboard', message=comment, log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}_{comment}") - - number_of_slots = 7 - slots_iter_num = 3 - - total_steps = 5e5 - warmup_steps = 1e4 - decay_rate = 0.5 - decay_steps = 1e5 - val_every = 1e4 - - model = SlottedAutoEncoder(number_of_slots, slots_iter_num).to(device) - # model = torch.compile(model) - optimizer = torch.optim.Adam(model.parameters(), lr=4e-4) - lr_warmup_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1/warmup_steps, total_iters=int(warmup_steps)) - lr_decay_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: decay_rate**(epoch/decay_steps)) - # lr_decay_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate**(1/decay_steps)) - lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler([lr_warmup_scheduler, lr_decay_scheduler]) - - global_step = 0 - prev_global_step = 0 - epoch = 0 - pbar = tqdm(total=total_steps, desc='Training') - while global_step < total_steps: - for sample_num, (img, target) in enumerate(train_data_loader): - res = model(img.to(device)) - recovered_img, vit_rec_loss = res['rec_img'], res['vit_rec_loss'] - - reg_loss = F.mse_loss(img.to(device), recovered_img) - - lambda_ = 0.1 - loss = lambda_ * reg_loss + (1 - lambda_) * vit_rec_loss - - optimizer.zero_grad() - loss.backward() - optimizer.step() - lr_scheduler.step() - - logger.add_scalar('train/img_rec_loss', reg_loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) - logger.add_scalar('train/vit_rec_loss', vit_rec_loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) - logger.add_scalar('train/loss', loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) - pbar.update(1) - global_step += len(train_data_loader) - - epoch += 1 - logger.add_scalar('epoch', epoch, epoch) - - if global_step - prev_global_step > val_every: - prev_global_step = global_step - else: - continue - - for i in range(3): - img, target = next(iter(train_data_loader)) - res = model(img.to(device)) - recovered_img, imgs_per_slot, vit_mask = res['rec_img'], res['img_per_slot'], res['vit_mask'] - upscale = tv.transforms.Resize(64, antialias=True) - unnormalize = tv.transforms.Compose([ - tv.transforms.Normalize((0, 0, 0), (1/0.229, 1/0.224, 1/0.225)), - tv.transforms.Normalize((-0.485, -0.456, -0.406), (1., 1., 1.)) - ]) - logger.add_image(f'val/example_image', unnormalize(img.cpu().detach()[0]), epoch*3 + i) - logger.add_image(f'val/example_image_rec', unnormalize(recovered_img.cpu().detach()[0]), epoch*3 + i) - per_slot_img = unnormalize(imgs_per_slot.cpu().detach())[0].permute((1, 2, 0, 3)).flatten(2, 3) - logger.add_image(f'val/example_image_slot_rec', per_slot_img, epoch*3 + i) - upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) - per_slot_vit = (upscaled_mask.unsqueeze(2) * img.to(device).unsqueeze(1))[0].permute(1, 2, 0, 3).flatten(2, 3) - logger.add_image(f'val/example_vit_slot_mask', per_slot_vit/upscaled_mask.max(), epoch*3 + i) - diff --git a/rl_sandbox/vision/slot_attention.py b/rl_sandbox/vision/slot_attention.py index 4caefe9..f0fdd81 100644 --- a/rl_sandbox/vision/slot_attention.py +++ b/rl_sandbox/vision/slot_attention.py @@ -1,69 +1,278 @@ import torch +import typing as t from torch import nn -from torch.nn import 
init +import torch.nn.functional as F +from jaxtyping import Float +import torchvision as tv +from tqdm import tqdm +import numpy as np + +from rl_sandbox.vision.dino import ViTFeat +from rl_sandbox.utils.logger import Logger class SlotAttention(nn.Module): - def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128): + def __init__(self, num_slots: int, n_dim: int, n_iter: int): super().__init__() - self.num_slots = num_slots - self.iters = iters - self.eps = eps - self.scale = dim ** -0.5 - self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) + self.n_slots = num_slots + self.n_iter = n_iter + self.n_dim = n_dim + self.scale = self.n_dim**(-1/2) + self.epsilon = 1e-8 + + self.slots_mu = nn.Parameter(torch.randn(1, 1, self.n_dim)) + self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.n_dim)) + nn.init.xavier_uniform_(self.slots_logsigma) + + self.slots_proj = nn.Linear(n_dim, n_dim) + self.slots_proj_2 = nn.Sequential( + nn.Linear(n_dim, n_dim*4), + nn.ReLU(inplace=True), + nn.Linear(n_dim*4, n_dim), + ) + self.slots_norm = nn.LayerNorm(self.n_dim) + self.slots_norm_2 = nn.LayerNorm(self.n_dim) + self.slots_reccur = nn.GRUCell(input_size=self.n_dim, hidden_size=self.n_dim) + + self.inputs_proj = nn.Linear(n_dim, n_dim*2) + self.inputs_norm = nn.LayerNorm(self.n_dim) + + def forward(self, X: Float[torch.Tensor, 'batch seq n_dim'], prev_slots: t.Optional[Float[torch.Tensor, 'batch num_slots n_dim']]) -> Float[torch.Tensor, 'batch num_slots n_dim']: + batch, _, _ = X.shape + k, v = self.inputs_proj(self.inputs_norm(X)).chunk(2, dim=-1) + + if prev_slots is None: + slots = self.slots_mu + self.slots_logsigma.exp() * torch.randn(batch, self.n_slots, self.n_dim, device=X.device) + else: + slots = prev_slots + + for _ in range(self.n_iter): + slots_prev = slots + slots = self.slots_norm(slots) + q = self.slots_proj(slots) + + attn = F.softmax(self.scale*torch.einsum('bik,bjk->bij', q, k), dim=1) + self.epsilon + attn = attn / attn.sum(dim=-1, keepdim=True) + + updates = torch.einsum('bij,bjk->bik', attn, v) / self.n_slots + slots = self.slots_reccur(updates.reshape(-1, self.n_dim), slots_prev.reshape(-1, self.n_dim)).reshape(batch, self.n_slots, self.n_dim) + slots = slots + self.slots_proj_2(self.slots_norm_2(slots)) + return slots - self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim)) - init.xavier_uniform_(self.slots_logsigma) +def build_grid(resolution): + ranges = [np.linspace(0., 1., num=res) for res in resolution] + grid = np.meshgrid(*ranges, sparse=False, indexing="ij") + grid = np.stack(grid, axis=-1) + grid = np.reshape(grid, [resolution[0], resolution[1], -1]) + grid = np.expand_dims(grid, axis=0) + grid = grid.astype(np.float32) + return np.concatenate([grid, 1.0 - grid], axis=-1) - self.to_q = nn.Linear(dim, dim) - self.to_k = nn.Linear(dim, dim) - self.to_v = nn.Linear(dim, dim) - self.gru = nn.GRUCell(dim, dim) +class PositionalEmbedding(nn.Module): + def __init__(self, n_dim: int, res: t.Tuple[int, int]): + super().__init__() + self.n_dim = n_dim + self.proj = nn.Linear(4, n_dim) + self.register_buffer('grid', torch.from_numpy(build_grid(res))) + + def forward(self, X) -> torch.Tensor: + return X + self.proj(self.grid).permute(0, 3, 1, 2) - hidden_dim = max(dim, hidden_dim) +class SlottedAutoEncoder(nn.Module): + def __init__(self, num_slots: int, n_iter: int, dino_inp_size: int = 224): + super().__init__() + in_channels = 3 + self.n_dim = 128 + self.lat_dim = int(self.n_dim**0.5) + self.encoder = nn.Sequential( + nn.Conv2d(in_channels, 64, 
kernel_size=5, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, self.n_dim, kernel_size=5, padding='same'), + nn.ReLU(inplace=True), + ) self.mlp = nn.Sequential( - nn.Linear(dim, hidden_dim), - nn.ReLU(inplace = True), - nn.Linear(hidden_dim, dim) + nn.Linear(self.n_dim, self.n_dim), + nn.ReLU(inplace=True), + nn.Linear(self.n_dim, self.n_dim) ) - self.norm_input = nn.LayerNorm(dim) - self.norm_slots = nn.LayerNorm(dim) - self.norm_pre_ff = nn.LayerNorm(dim) + self.dino_inp_size = dino_inp_size + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=16) + self.vit_patch_num = self.dino_inp_size // self.dino_vit.patch_size + self.vit_feat = self.dino_vit.feat_dim - def forward(self, inputs, num_slots = None): - b, n, d, device, dtype = *inputs.shape, inputs.device, inputs.dtype - n_s = num_slots if num_slots is not None else self.num_slots + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (7, 7)) + self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) + self.positional_augmenter_vit_dec = PositionalEmbedding(self.n_dim, (self.lat_dim, self.lat_dim)) + self.slot_attention = SlotAttention(num_slots, self.n_dim, n_iter) + self.img_decoder = nn.Sequential( # Dx8x8 -> (3+1)x64x64 + nn.ConvTranspose2d(self.n_dim, 48, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(48, 96, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(96, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ConvTranspose2d(192, 4, kernel_size=3, stride=(1, 1), padding=1), + nn.ReLU(inplace=True), + ) - mu = self.slots_mu.expand(b, n_s, -1) - sigma = self.slots_logsigma.exp().expand(b, n_s, -1) + self.vit_decoder = nn.Sequential( # Dx1x1 -> (384+1)x14x14 + nn.ConvTranspose2d(self.n_dim, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(192, 192, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(192, self.vit_feat, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(self.vit_feat, self.vit_feat*2, kernel_size=3, stride=(2, 2), padding=2, output_padding=1), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(self.vit_feat*2, self.vit_feat+1, kernel_size=3, stride=(1, 1), padding=1), + ) - slots = mu + sigma * torch.randn(mu.shape, device = device, dtype = dtype) + # self.vit_decoder_mlp = nn.Sequential( + # nn.Linear(self.n_dim, 1024), + # nn.ReLU(inplace=True), + # nn.Linear(1024, 1024), + # nn.ReLU(inplace=True), + # nn.Linear(1024, 1024), + # nn.ReLU(inplace=True), + # nn.Linear(1024, self.vit_feat+1), + # nn.ReLU(inplace=True) + # ) - inputs = self.norm_input(inputs) - k, v = self.to_k(inputs), self.to_v(inputs) + def forward(self, X: Float[torch.Tensor, 'batch 3 h w'], prev_slots: t.Optional[Float[torch.Tensor, 'batch num_slots n_dim']] = None) -> t.Dict[str, torch.Tensor]: + features = self.encoder(X) # -> batch D h w + features_with_pos_enc = self.positional_augmenter_inp(features) # -> batch D h w - for _ in range(self.iters): - slots_prev = slots + resize = tv.transforms.Resize(self.dino_inp_size, antialias=True) - slots = self.norm_slots(slots) 
- q = self.to_q(slots) + batch, seq, _, _ = X.shape + vit_features = self.dino_vit(resize(X)) + vit_features = vit_features.reshape(batch, -1, self.vit_patch_num, self.vit_patch_num) - dots = torch.einsum('bid,bjd->bij', q, k) * self.scale - attn = dots.softmax(dim=1) + self.eps + pre_slot_features = self.mlp(features_with_pos_enc.permute(0, 2, 3, 1).reshape(batch, -1, self.n_dim)) - attn = attn / attn.sum(dim=-1, keepdim=True) + slots = self.slot_attention(pre_slot_features, prev_slots) # -> batch num_slots D + slots_grid = slots.flatten(0, 1).reshape(-1, 1, 1, self.n_dim).permute(0, 3, 1, 2) - updates = torch.einsum('bjd,bij->bid', v, attn) + # slots_with_vit_pos_enc = self.positional_augmenter_vit_dec(slots_grid.flatten(2, 3).repeat((1, 1, 196)).reshape(-1, self.n_dim, self.lat_dim, self.lat_dim)).flatten(2, 3) + # decoded_features, vit_masks =self.vit_decoder_mlp(slots_with_vit_pos_enc).reshape(batch, -1, self.vit_patch_num, self.vit_patch_num, self.vit_feat+1).split([self.vit_feat, 1], dim=-1) - slots = self.gru( - updates.reshape(-1, d), - slots_prev.reshape(-1, d) - ) + decoded_features, vit_masks = self.vit_decoder(slots_grid).permute(0, 2, 3, 1).reshape(batch, -1, self.vit_patch_num, self.vit_patch_num, self.vit_feat+1).split([self.vit_feat, 1], dim=-1) + vit_mask = F.softmax(vit_masks, dim=1) - slots = slots.reshape(b, -1, d) - slots = slots + self.mlp(self.norm_pre_ff(slots)) + rec_features = (decoded_features * vit_mask).sum(dim=1) + + slots_grid = slots_grid.repeat((1, 1, 8, 8)) # -> batch*num_slots D sqrt(D) sqrt(D) + slots_with_pos_enc = self.positional_augmenter_dec(slots_grid) + + decoded_imgs, masks = self.img_decoder(slots_with_pos_enc).permute(0, 2, 3, 1).reshape(batch, -1, *np.array(X.shape[2:]), 4).split([3, 1], dim=-1) + img_mask = F.softmax(masks, dim=1) + + decoded_imgs = decoded_imgs * img_mask + rec_img = torch.sum(decoded_imgs, dim=1) + return { + 'rec_img': rec_img.permute(0, 3, 1, 2), + 'img_per_slot': decoded_imgs.permute(0, 1, 4, 2, 3), + 'vit_mask': vit_mask, + 'vit_rec_loss': F.mse_loss(rec_features.permute(0, 3, 1, 2), vit_features), + 'slots': slots + } + +if __name__ == '__main__': + device = 'cuda' + debug = False + ToTensor = tv.transforms.Compose([tv.transforms.ToTensor(), + tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + + train_data = tv.datasets.ImageFolder('~/rl_sandbox/crafter_data/', transform=ToTensor) + if debug: + train_data_loader = torch.utils.data.DataLoader(train_data, + batch_size=4, + prefetch_factor=1, + shuffle=False, + num_workers=2) + else: + train_data_loader = torch.utils.data.DataLoader(train_data, + batch_size=32, + shuffle=True, + num_workers=8) + + import socket + from datetime import datetime + current_time = datetime.now().strftime("%b%d_%H-%M-%S") + comment = "Added vit masks logging, lambda=0.1, return old dino".replace(" ", "_") + logger = Logger(None if debug else 'tensorboard', message=comment, log_dir=f"vae_tmp/{current_time}_{socket.gethostname()}_{comment}") + + number_of_slots = 7 + slots_iter_num = 3 + + total_steps = 5e5 + warmup_steps = 1e4 + decay_rate = 0.5 + decay_steps = 1e5 + val_every = 1e4 + + model = SlottedAutoEncoder(number_of_slots, slots_iter_num).to(device) + # model = torch.compile(model) + optimizer = torch.optim.Adam(model.parameters(), lr=4e-4) + lr_warmup_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1/warmup_steps, total_iters=int(warmup_steps)) + lr_decay_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 
decay_rate**(epoch/decay_steps)) + # lr_decay_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate**(1/decay_steps)) + lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler([lr_warmup_scheduler, lr_decay_scheduler]) + + global_step = 0 + prev_global_step = 0 + epoch = 0 + pbar = tqdm(total=total_steps, desc='Training') + while global_step < total_steps: + for sample_num, (img, target) in enumerate(train_data_loader): + res = model(img.to(device)) + recovered_img, vit_rec_loss = res['rec_img'], res['vit_rec_loss'] + + reg_loss = F.mse_loss(img.to(device), recovered_img) + + lambda_ = 0.1 + loss = lambda_ * reg_loss + (1 - lambda_) * vit_rec_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + logger.add_scalar('train/img_rec_loss', reg_loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) + logger.add_scalar('train/vit_rec_loss', vit_rec_loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) + logger.add_scalar('train/loss', loss.cpu().detach(), epoch * len(train_data_loader) + sample_num) + pbar.update(1) + global_step += len(train_data_loader) + + epoch += 1 + logger.add_scalar('epoch', epoch, epoch) + + if global_step - prev_global_step > val_every: + prev_global_step = global_step + else: + continue + + for i in range(3): + img, target = next(iter(train_data_loader)) + res = model(img.to(device)) + recovered_img, imgs_per_slot, vit_mask = res['rec_img'], res['img_per_slot'], res['vit_mask'] + upscale = tv.transforms.Resize(64, antialias=True) + unnormalize = tv.transforms.Compose([ + tv.transforms.Normalize((0, 0, 0), (1/0.229, 1/0.224, 1/0.225)), + tv.transforms.Normalize((-0.485, -0.456, -0.406), (1., 1., 1.)) + ]) + logger.add_image(f'val/example_image', unnormalize(img.cpu().detach()[0]), epoch*3 + i) + logger.add_image(f'val/example_image_rec', unnormalize(recovered_img.cpu().detach()[0]), epoch*3 + i) + per_slot_img = unnormalize(imgs_per_slot.cpu().detach())[0].permute((1, 2, 0, 3)).flatten(2, 3) + logger.add_image(f'val/example_image_slot_rec', per_slot_img, epoch*3 + i) + upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) + per_slot_vit = (upscaled_mask.unsqueeze(2) * img.to(device).unsqueeze(1))[0].permute(1, 2, 0, 3).flatten(2, 3) + logger.add_image(f'val/example_vit_slot_mask', per_slot_vit/upscaled_mask.max(), epoch*3 + i) - return slots From ad411de49cfafb3bbe67cfc909aa68fa7aa7f8bc Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Sun, 23 Apr 2023 14:06:40 +0100 Subject: [PATCH 060/106] Added dockerfiles with installation guide --- Dockerfile | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++ README.md | 18 +++++++++++++++ pyproject.toml | 2 +- 3 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 Dockerfile create mode 100644 README.md diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..cd179cf --- /dev/null +++ b/Dockerfile @@ -0,0 +1,61 @@ +ARG BASE_IMAGE=nvidia/cudagl:11.3.0-devel +FROM $BASE_IMAGE + +ARG USER_ID +ARG GROUP_ID +ARG USER_NAME=user + +RUN apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y ssh gcc g++ gdb clang rsync tar python sudo git ffmpeg ninja-build locales \ + && apt-get clean \ + && sudo rm -rf /var/lib/apt/lists/* + +RUN ( \ + echo 'LogLevel DEBUG2'; \ + echo 'PermitRootLogin yes'; \ + echo 'PasswordAuthentication yes'; \ + echo 'Subsystem sftp /usr/lib/openssh/sftp-server'; \ + ) > /etc/ssh/sshd_config_test_clion \ + && mkdir /run/sshd + +RUN groupadd -g ${GROUP_ID} 
${USER_NAME} && \ + useradd -u ${USER_ID} -g ${GROUP_ID} -s /bin/bash -m ${USER_NAME} && \ + yes password | passwd ${USER_NAME} && \ + usermod -aG sudo ${USER_NAME} && \ + echo "${USER_NAME} ALL=(ALL) NOPASSWD:ALL" | sudo tee /etc/sudoers.d/user && \ + chmod 440 /etc/sudoers + +USER ${USER_NAME} + +RUN git clone https://github.com/Midren/dotfiles /home/${USER_NAME}/.dotfiles && \ + /home/${USER_NAME}/.dotfiles/install-profile ubuntu-cli + +RUN git config --global user.email "milromchuk@gmail.com" && \ + git config --global user.name "Roman Milishchuk" + +USER root + +RUN apt-get update \ + && apt-get install -y software-properties-common curl \ + && add-apt-repository -y ppa:deadsnakes/ppa \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y python3.10 python3.10-dev python3.10-venv \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 \ + && apt-get clean \ + && sudo rm -rf /var/lib/apt/lists/* + +RUN sudo update-alternatives --install /usr/bin/python3 python /usr/bin/python3.10 1 \ + && sudo update-alternatives --install /usr/bin/python python3 /usr/bin/python3.10 1 + +USER ${USER_NAME} +WORKDIR /home/${USER_NAME}/ + +RUN mkdir /home/${USER_NAME}/rl_sandbox + +COPY pyproject.toml /home/${USER_NAME}/rl_sandbox/pyproject.toml +COPY rl_sandbox /home/${USER_NAME}/rl_sandbox/rl_sandbox + +RUN cd /home/${USER_NAME}/rl_sandbox \ + && python3.10 -m pip install --no-cache-dir -e . \ + && rm -Rf /home/${USER_NAME}/.cache/pip + + diff --git a/README.md b/README.md new file mode 100644 index 0000000..bfa5b71 --- /dev/null +++ b/README.md @@ -0,0 +1,18 @@ +## RL sandbox + +## Run + +Build docker: +```sh +docker build --build-arg USER_ID=$(id -u) --build-arg GROUP_ID=$(id -g) --build-arg USER_NAME=$USER -t dreamer . +``` + +Run docker with tty: +```sh +docker run --gpus 'all' -it --rm -v `pwd`:/home/$USER/rl_sandbox -w /home/$USER/rl_sandbox dreamer zsh +``` + +Run training inside docker on gpu 0: +```sh +docker run --gpus 'device=0' -it --rm -v `pwd`:/home/$USER/rl_sandbox -w /home/$USER/rl_sandbox dreamer python3 rl_sandbox/train.py +``` diff --git a/pyproject.toml b/pyproject.toml index 94c5203..e0790ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ gym = "0.25.0" # crafter requires old step api pygame = '*' moviepy = '*' torchvision = '*' -torch = '*' +torch = '^2.0' tensorboard = '^2.0' dm-control = '^1.0.0' unpackable = '^0.0.4' From 7d2f6b1d0c18d3472f341c5c54a7f6259fef9d81 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Sun, 23 Apr 2023 15:46:13 +0100 Subject: [PATCH 061/106] Slot attention debug --- rl_sandbox/agents/dreamer_v2.py | 22 +++++++++++----------- rl_sandbox/config/agent/dreamer_v2.yaml | 10 +++++----- rl_sandbox/config/config.yaml | 8 ++++---- rl_sandbox/train.py | 2 +- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index f48ba81..65b60fe 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -210,8 +210,8 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret # For observation we do not have ensemble # FIXME: very bad magic number - img_sz = 4 * 384 # 384*2x2 - # img_sz = 192 + # img_sz = 4 * 384 # 384x2x2 + img_sz = 192 self.stoch_net = nn.Sequential( # nn.LayerNorm(hidden_size + img_sz, hidden_size), nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' @@ -422,14 +422,14 @@ def __init__(self, img_size, batch_cluster_size, latent_dim, latent_classes, rss self.encoder = 
Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm) self.n_dim = 192 - self.slot_attention = SlotAttention(slots_num, 1536, 3) + self.slot_attention = SlotAttention(slots_num, self.n_dim, 5) self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) self.slot_mlp = nn.Sequential( - nn.Linear(192, 768), + nn.Linear(self.n_dim, self.n_dim), nn.ReLU(inplace=True), - nn.Linear(768, 1536) + nn.Linear(self.n_dim, self.n_dim) ) @@ -489,7 +489,7 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State], prev_s slots_t = self.slot_attention(pre_slot_features_t, prev_slots) _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), action) - return posterior, None + return posterior, slots_t def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, discount: torch.Tensor, first: torch.Tensor): @@ -500,7 +500,7 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, # embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) pre_slot_features = self.slot_mlp(embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) - pre_slot_features_c = pre_slot_features.reshape(b // self.cluster_size, self.cluster_size, -1, 1536) + pre_slot_features_c = pre_slot_features.reshape(b // self.cluster_size, self.cluster_size, -1, self.n_dim) a_c = a.reshape(-1, self.cluster_size, self.actions_num) r_c = r.reshape(-1, self.cluster_size, 1) @@ -540,7 +540,7 @@ def KL(dist1, dist2): a_t = a_t * (1 - first_t) slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) - prev_slots = None + # prev_slots = None prior, posterior, diff = self.recurrent_model.forward(prev_state, slots_t.unsqueeze(0), a_t) prev_state = posterior @@ -728,7 +728,7 @@ def __init__( eps=1e-5, weight_decay=1e-6) - warmup_steps = 1e4 + warmup_steps = 1e3 decay_rate = 0.5 decay_steps = 5e5 lr_warmup_scheduler = torch.optim.lr_scheduler.LinearLR(self.world_model_optimizer, start_factor=1/warmup_steps, total_iters=int(warmup_steps)) @@ -936,7 +936,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation first_flags = self.from_np(is_first).type(torch.float32) # take some latent embeddings as initial - with torch.cuda.amp.autocast(enabled=True): + with torch.cuda.amp.autocast(enabled=False): losses, discovered_states, wm_metrics = self.world_model.calculate_loss(obs, a, r, discount_factors, first_flags) self.world_model.recurrent_model.discretizer_scheduler.step() @@ -968,7 +968,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation metrics = wm_metrics - with torch.cuda.amp.autocast(enabled=True): + with torch.cuda.amp.autocast(enabled=False): losses_ac = {} initial_states = State(discovered_states.determ.flatten(0, 1).unsqueeze(0).detach(), discovered_states.stoch_logits.flatten(0, 1).unsqueeze(0).detach(), diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index 07967c6..c1ac64c 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -2,11 +2,11 @@ _target_: rl_sandbox.agents.DreamerV2 layer_norm: true # World model parameters batch_cluster_size: 50 -latent_dim: 32 -latent_classes: 32 -rssm_dim: 200 -slots_num: 1 -kl_loss_scale: 1.0 +latent_dim: 16 +latent_classes: 16 +rssm_dim: 40 +slots_num: 8 +kl_loss_scale: 8.0 kl_loss_balancing: 0.8 kl_loss_free_nats: 0.0 world_model_lr: 3e-4 diff --git 
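With the `prev_slots = None` reset commented out above, the slots from step t now initialise slot attention at step t+1 instead of being resampled from the learned Gaussian prior, which keeps slot identities stable across a sequence. A sketch of that recurrence with illustrative shapes (36 tokens from the 6x6 feature map, `n_dim=192`); the feature tensor is a random placeholder:

```python
import torch
from rl_sandbox.vision.slot_attention import SlotAttention

B, T, tokens, n_dim, slots_num = 2, 5, 36, 192, 8
pre_slot_features = torch.randn(B, T, tokens, n_dim)   # placeholder encoder+MLP features

slot_attention = SlotAttention(slots_num, n_dim, n_iter=5)
prev_slots, per_step = None, []
for t in range(T):
    slots_t = slot_attention(pre_slot_features[:, t], prev_slots)   # [B, slots_num, n_dim]
    prev_slots = slots_t              # carried over instead of being reset to None
    per_step.append(slots_t)

slots_seq = torch.stack(per_step, dim=1)                # [B, T, slots_num, n_dim]
print(slots_seq.shape)
```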
a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 84cd395..a3af136 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -8,8 +8,8 @@ seed: 42 device_type: cuda logger: - type: null - message: Cartpole with only 1 slot without KL + type: tensorboard + message: Cartpole 8 slots, reduced warmup, 192 n_dim, correct prev_slots, 8x KL, 0 nats, 40 rssm dims, 16x16 stoch #message: test_last log_grads: false @@ -17,9 +17,9 @@ training: checkpoint_path: null steps: 1e6 #prefill: 0 - pretrain: 0 + #pretrain: 0 batch_size: 16 - val_logs_every: 2.5e3 + val_logs_every: 1e2 validation: diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index a6fb2d5..8781461 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -136,7 +136,7 @@ def main(cfg: DictConfig): # FIXME: find more appealing solution ### Validation - if (global_step % cfg.training.val_logs_every) < (prev_global_step % + if (global_step % cfg.training.val_logs_every) <= (prev_global_step % cfg.training.val_logs_every): val_logs(agent, cfg.validation, val_env, global_step, logger) From e5b48a2f52cf19a433bcbc990841af10cf314873 Mon Sep 17 00:00:00 2001 From: Midren Date: Sun, 23 Apr 2023 18:13:39 +0100 Subject: [PATCH 062/106] Added parameter sweep configuration to run on each of gpu --- rl_sandbox/config/config.yaml | 9 +++++++++ rl_sandbox/train.py | 10 +++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index a3af136..6538b89 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -3,6 +3,7 @@ defaults: - env: dm_cartpole - training: dm - _self_ + - override hydra/launcher: joblib seed: 42 device_type: cuda @@ -31,3 +32,11 @@ validation: debug: profiler: false + +hydra: + mode: MULTIRUN + launcher: + n_jobs: 1 + sweeper: + params: [] + #agent.kl_loss_scale: 1.0,2.0,4.0,8.0,12.0,16.0,22.0,32.0 diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 8781461..04532ba 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -9,6 +9,8 @@ import torch from gym.spaces import Discrete from omegaconf import DictConfig +from hydra.core.hydra_config import HydraConfig +from hydra.types import RunMode from torch.profiler import ProfilerActivity, profile from tqdm import tqdm @@ -45,7 +47,7 @@ def main(cfg: DictConfig): lt.monkey_patch() # print(OmegaConf.to_yaml(cfg)) torch.distributions.Distribution.set_default_validate_args(False) - torch.backends.cudnn.benchmark = True + eval('setattr(torch.backends.cudnn, "benchmark", True)') # need to be pickable for multirun torch.backends.cuda.matmul.allow_tf32 = True if torch.cuda.is_available(): @@ -54,6 +56,12 @@ def main(cfg: DictConfig): torch.manual_seed(cfg.seed) np.random.seed(cfg.seed) + if HydraConfig.get()['mode'] == RunMode.MULTIRUN and cfg.device_type == 'cuda': + num_gpus = torch.cuda.device_count() + gpu_id = HydraConfig.get().job.num % num_gpus + cfg.device_type = f'cuda:{gpu_id}' + cfg.logger.message += "," + ",".join(HydraConfig.get()['overrides']['task']) + # TODO: Implement smarter techniques for exploration # (Plan2Explore, etc) logger = Logger(**cfg.logger) From e4b2ec4d159b9e0e00a0cf6aef79afea2c5a39ec Mon Sep 17 00:00:00 2001 From: Midren Date: Mon, 24 Apr 2023 07:48:33 +0100 Subject: [PATCH 063/106] Added bigger encoder --- rl_sandbox/agents/dreamer_v2.py | 16 +++++++++++----- rl_sandbox/config/config.yaml | 11 ++++++----- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py 
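The sweep support above maps each Hydra multirun job onto a GPU by taking `job.num` modulo the device count, so each sweep job lands on its own GPU when the joblib `n_jobs` matches the number of GPUs. A condensed sketch of that device selection as a hypothetical `pick_device` helper, meant to be called from inside a `@hydra.main`-decorated function:

```python
import torch
from hydra.core.hydra_config import HydraConfig
from hydra.types import RunMode

def pick_device(requested: str = 'cuda') -> str:
    """Pin each multirun job to one GPU; otherwise keep the requested device."""
    hc = HydraConfig.get()
    if hc['mode'] == RunMode.MULTIRUN and requested == 'cuda' and torch.cuda.is_available():
        gpu_id = hc.job.num % torch.cuda.device_count()
        return f'cuda:{gpu_id}'
    return requested
```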
b/rl_sandbox/agents/dreamer_v2.py index 65b60fe..8b71c52 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -179,7 +179,7 @@ class RSSM(nn.Module): """ - def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discrete_rssm, norm_layer: nn.LayerNorm | nn.Identity): + def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discrete_rssm, norm_layer: nn.LayerNorm | nn.Identity, embed_size = 2*2*384): super().__init__() self.latent_dim = latent_dim self.latent_classes = latent_classes @@ -211,7 +211,8 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret # For observation we do not have ensemble # FIXME: very bad magic number # img_sz = 4 * 384 # 384x2x2 - img_sz = 192 + # img_sz = 192 + img_sz = embed_size self.stoch_net = nn.Sequential( # nn.LayerNorm(hidden_size + img_sz, hidden_size), nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' @@ -277,13 +278,16 @@ def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[4, 4, 4 super().__init__() layers = [] - channel_step = 48 + channel_step = 96 in_channels = 3 for i, k in enumerate(kernel_sizes): out_channels = 2**i * channel_step layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2)) layers.append(norm_layer(1, out_channels)) layers.append(nn.ELU(inplace=True)) + layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')) + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) in_channels = out_channels # layers.append(nn.Flatten()) self.net = nn.Sequential(*layers) @@ -393,12 +397,15 @@ def __init__(self, img_size, batch_cluster_size, latent_dim, latent_classes, rss self.decode_vit = decode_vit self.vit_l2_ratio = vit_l2_ratio + self.n_dim = 384 + self.recurrent_model = RSSM(latent_dim, rssm_dim, actions_num, latent_classes, discrete_rssm, - norm_layer=nn.Identity if layer_norm else nn.LayerNorm) + norm_layer=nn.Identity if layer_norm else nn.LayerNorm, + embed_size=self.n_dim) if encode_vit or decode_vit: # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) # self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) @@ -421,7 +428,6 @@ def __init__(self, img_size, batch_cluster_size, latent_dim, latent_classes, rss else: self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm) - self.n_dim = 192 self.slot_attention = SlotAttention(slots_num, self.n_dim, 5) self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 6538b89..6315b4b 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -10,7 +10,7 @@ device_type: cuda logger: type: tensorboard - message: Cartpole 8 slots, reduced warmup, 192 n_dim, correct prev_slots, 8x KL, 0 nats, 40 rssm dims, 16x16 stoch + message: Cartpole 4 slots, double encoder, reduced warmup, 384 n_dim, correct prev_slots, 0 nats, 40 rssm dims, 16x16 stoch #message: test_last log_grads: false @@ -20,7 +20,7 @@ training: #prefill: 0 #pretrain: 0 batch_size: 16 - val_logs_every: 1e2 + val_logs_every: 4e3 validation: @@ -35,8 +35,9 @@ debug: hydra: mode: MULTIRUN + #mode: RUN launcher: - n_jobs: 1 + n_jobs: 8 sweeper: - params: [] - #agent.kl_loss_scale: 
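
The hunk above adds the joblib launcher and assigns each sweep job to a GPU by its index. A minimal, self-contained sketch of the same pattern; the script and override names are hypothetical, only the HydraConfig/RunMode calls mirror the patch:

    import hydra
    import torch
    from hydra.core.hydra_config import HydraConfig
    from hydra.types import RunMode
    from omegaconf import DictConfig

    @hydra.main(version_base=None, config_path=None, config_name=None)
    def main(cfg: DictConfig):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        hc = HydraConfig.get()
        if hc.mode == RunMode.MULTIRUN and device == 'cuda':
            # job.num enumerates jobs inside the sweep -> round-robin over GPUs
            device = f'cuda:{hc.job.num % torch.cuda.device_count()}'
        print(f'running on {device}')

    if __name__ == '__main__':
        # e.g.: python sweep.py -m hydra/launcher=joblib +probe=1,2,3,4
        main()

With n_jobs matching the GPU count, each concurrent joblib worker lands on its own device.
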
1.0,2.0,4.0,8.0,12.0,16.0,22.0,32.0 + params: + agent.kl_loss_scale: 4.0,8.0,12.0,16.0,22.0,28.0,32.0,48.0 From 1f66d019d7422d6c36f0390f1808a07e91d36226 Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 25 Apr 2023 16:11:49 +0100 Subject: [PATCH 064/106] parameters without decoder collapse --- rl_sandbox/config/agent/dreamer_v2.yaml | 4 ++-- rl_sandbox/config/config.yaml | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index c1ac64c..20c1718 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -6,9 +6,9 @@ latent_dim: 16 latent_classes: 16 rssm_dim: 40 slots_num: 8 -kl_loss_scale: 8.0 +kl_loss_scale: 32.0 kl_loss_balancing: 0.8 -kl_loss_free_nats: 0.0 +kl_loss_free_nats: 0.00 world_model_lr: 3e-4 world_model_predict_discount: false diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 6315b4b..0737b54 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -10,7 +10,7 @@ device_type: cuda logger: type: tensorboard - message: Cartpole 4 slots, double encoder, reduced warmup, 384 n_dim, correct prev_slots, 0 nats, 40 rssm dims, 16x16 stoch + message: 32 KL, Cartpole 8 slots, double encoder, reduced warmup, 384 n_dim, correct prev_slots, 0.00 nats, 40 rssm dims, 16x16 stoch #message: test_last log_grads: false @@ -34,10 +34,11 @@ debug: profiler: false hydra: - mode: MULTIRUN - #mode: RUN + #mode: MULTIRUN + mode: RUN launcher: - n_jobs: 8 + #n_jobs: 8 + n_jobs: 1 sweeper: params: - agent.kl_loss_scale: 4.0,8.0,12.0,16.0,22.0,28.0,32.0,48.0 + #agent.kl_loss_scale: 1.0,4.0,16.0,32.0,64.0,96.0,128.0 From ff83255e414e50e41b3740ad69969dffa7f53034 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Wed, 26 Apr 2023 16:35:39 +0000 Subject: [PATCH 065/106] Refactored to have separate entities in hydra --- pyproject.toml | 2 + rl_sandbox/agents/dreamer/__init__.py | 1 + rl_sandbox/agents/dreamer/ac.py | 137 ++++ rl_sandbox/agents/dreamer/common.py | 71 ++ rl_sandbox/agents/dreamer/rssm.py | 208 ++++++ rl_sandbox/agents/dreamer/vision.py | 89 +++ rl_sandbox/agents/dreamer/world_model.py | 249 +++++++ rl_sandbox/agents/dreamer_v2.py | 819 +---------------------- rl_sandbox/config/agent/dreamer_v2.yaml | 98 ++- rl_sandbox/utils/logger.py | 4 +- rl_sandbox/utils/optimizer.py | 71 ++ 11 files changed, 934 insertions(+), 815 deletions(-) create mode 100644 rl_sandbox/agents/dreamer/__init__.py create mode 100644 rl_sandbox/agents/dreamer/ac.py create mode 100644 rl_sandbox/agents/dreamer/common.py create mode 100644 rl_sandbox/agents/dreamer/rssm.py create mode 100644 rl_sandbox/agents/dreamer/vision.py create mode 100644 rl_sandbox/agents/dreamer/world_model.py create mode 100644 rl_sandbox/utils/optimizer.py diff --git a/pyproject.toml b/pyproject.toml index e0790ff..256f066 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ authors = ['Roman Milishchuk '] packages = [{include = 'rl_sandbox'}] # add config directory as package data +# TODO: add yapf and isort as development dependencies [tool.poetry.dependencies] python = "^3.10" numpy = '*' @@ -29,6 +30,7 @@ jaxtyping = '^0.2.0' lovely_tensors = '^0.1.10' torchshow = '^0.5.0' crafter = '^1.8.0' +hydra-joblib-launcher = "*" [tool.yapf] based_on_style = "pep8" diff --git a/rl_sandbox/agents/dreamer/__init__.py b/rl_sandbox/agents/dreamer/__init__.py new file mode 100644 index 0000000..55e5f84 --- /dev/null +++ 
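
This refactor moves sub-module construction into the Hydra agent config (the diffstat above shows dreamer_v2.yaml reworked), and DreamerV2 later finishes construction with calls like world_model(actions_num=actions_num). A hedged sketch of the `_partial_` instantiation pattern that supports such deferred arguments (requires Hydra >= 1.2; the Linear target is only a stand-in, not the project's config):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        '_target_': 'torch.nn.Linear',
        '_partial_': True,        # instantiate() returns a functools.partial
        'out_features': 8,
    })
    make_layer = instantiate(cfg)
    layer = make_layer(in_features=4)   # remaining arguments supplied by the caller
    print(layer)                        # Linear(in_features=4, out_features=8, bias=True)
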
b/rl_sandbox/agents/dreamer/__init__.py @@ -0,0 +1 @@ +from .common import * diff --git a/rl_sandbox/agents/dreamer/ac.py b/rl_sandbox/agents/dreamer/ac.py new file mode 100644 index 0000000..515e7e6 --- /dev/null +++ b/rl_sandbox/agents/dreamer/ac.py @@ -0,0 +1,137 @@ +import typing as t + +import torch +import torch.distributions as td +from torch import nn + +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.utils.fc_nn import fc_nn_generator + + +class ImaginativeCritic(nn.Module): + + def __init__(self, discount_factor: float, update_interval: int, + soft_update_fraction: float, value_target_lambda: float, latent_dim: int, + layer_norm: bool): + super().__init__() + self.gamma = discount_factor + self.critic_update_interval = update_interval + self.lambda_ = value_target_lambda + self.critic_soft_update_fraction = soft_update_fraction + self._update_num = 0 + + self.critic = fc_nn_generator(latent_dim, + 1, + 400, + 5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.target_critic = fc_nn_generator(latent_dim, + 1, + 400, + 5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.target_critic.requires_grad_(False) + + def update_target(self): + if self._update_num == 0: + self.target_critic.load_state_dict(self.critic.state_dict()) + # for target_param, local_param in zip(self.target_critic.parameters(), + # self.critic.parameters()): + # mix = self.critic_soft_update_fraction + # target_param.data.copy_(mix * local_param.data + + # (1 - mix) * target_param.data) + self._update_num = (self._update_num + 1) % self.critic_update_interval + + def estimate_value(self, z) -> td.Distribution: + return self.critic(z) + + def _lambda_return(self, vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor): + # Formula is actually slightly different than in paper + # https://github.com/danijar/dreamerv2/issues/25 + v_lambdas = [vs[-1]] + for i in range(rs.shape[0] - 1, -1, -1): + v_lambda = rs[i] + ds[i] * ( + (1 - self.lambda_) * vs[i + 1] + self.lambda_ * v_lambdas[-1]) + v_lambdas.append(v_lambda) + + # FIXME: it copies array, so it is quite slow + return torch.stack(v_lambdas).flip(dims=(0, ))[:-1] + + def lambda_return(self, zs, rs, ds): + vs = self.target_critic(zs).mode + return self._lambda_return(vs, rs, ds) + + def calculate_loss(self, zs: torch.Tensor, vs: torch.Tensor, + discount_factors: torch.Tensor): + predicted_vs_dist = self.estimate_value(zs.detach()) + losses = { + 'loss_critic': + -(predicted_vs_dist.log_prob(vs.detach()).unsqueeze(2) * + discount_factors).mean() + } + metrics = { + 'critic/avg_target_value': self.target_critic(zs).mode.mean(), + 'critic/avg_lambda_value': vs.mean(), + 'critic/avg_predicted_value': predicted_vs_dist.mode.mean() + } + return losses, metrics + + +class ImaginativeActor(nn.Module): + + def __init__(self, latent_dim: int, actions_num: int, is_discrete: bool, + layer_norm: bool, reinforce_fraction: t.Optional[float], + entropy_scale: float): + super().__init__() + self.rho = reinforce_fraction + if self.rho is None: + self.rho = is_discrete + self.eta = entropy_scale + self.actor = fc_nn_generator( + latent_dim, + actions_num if is_discrete else actions_num * 2, + 400, + 5, + layer_norm=layer_norm, + intermediate_activation=nn.ELU, + final_activation=DistLayer('onehot' if is_discrete else 'normal_trunc')) + + def forward(self, z: torch.Tensor) -> td.Distribution: + return self.actor(z) + + def calculate_loss(self, zs: torch.Tensor, vs: 
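
The _lambda_return recursion in ac.py above computes the TD(lambda) value targets used by DreamerV2 (the indexing subtlety is the one discussed in the linked dreamerv2 issue). A standalone sketch of the same backward recursion with a hand-checkable case:

    import torch

    def lambda_return(vs, rs, ds, lam):
        # vs has one more entry than rs/ds: the bootstrap value v_T at the end
        out = [vs[-1]]
        for i in range(rs.shape[0] - 1, -1, -1):
            out.append(rs[i] + ds[i] * ((1 - lam) * vs[i + 1] + lam * out[-1]))
        return torch.stack(out).flip(dims=(0,))[:-1]

    rs = torch.tensor([1.0, 1.0])       # rewards r_0, r_1
    ds = torch.tensor([0.9, 0.9])       # discount * continuation
    vs = torch.tensor([0.0, 2.0, 3.0])  # critic values v_0, v_1 and bootstrap v_2
    # lam=1 reduces to the n-step return: 1 + 0.9 * (1 + 0.9 * 3) = 4.33
    print(lambda_return(vs, rs, ds, lam=1.0))   # tensor([4.3300, 3.7000])
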
torch.Tensor, baseline: torch.Tensor, + discount_factors: torch.Tensor, actions: torch.Tensor): + losses = {} + metrics = {} + action_dists = self.actor(zs.detach()) + # baseline = + advantage = (vs - baseline).detach() + losses['loss_actor_reinforce'] = -(self.rho * action_dists.log_prob( + actions.detach()).unsqueeze(2) * discount_factors * advantage).mean() + losses['loss_actor_dynamics_backprop'] = -((1 - self.rho) * + (vs * discount_factors)).mean() + + def calculate_entropy(dist): + # return dist.base_dist.base_dist.entropy().unsqueeze(2) + return dist.entropy().unsqueeze(2) + + losses['loss_actor_entropy'] = -(self.eta * calculate_entropy(action_dists) * + discount_factors).mean() + losses['loss_actor'] = losses['loss_actor_reinforce'] + losses[ + 'loss_actor_dynamics_backprop'] + losses['loss_actor_entropy'] + + # mean and std are estimated statistically as tanh transformation is used + sample = action_dists.rsample((128, )) + act_avg = sample.mean(0) + metrics['actor/avg_val'] = act_avg.mean() + # metrics['actor/mode_val'] = action_dists.mode.mean() + metrics['actor/mean_val'] = action_dists.mean.mean() + metrics['actor/avg_sd'] = (((sample - act_avg)**2).mean(0).sqrt()).mean() + metrics['actor/min_val'] = sample.min() + metrics['actor/max_val'] = sample.max() + + return losses, metrics diff --git a/rl_sandbox/agents/dreamer/common.py b/rl_sandbox/agents/dreamer/common.py new file mode 100644 index 0000000..86ff51f --- /dev/null +++ b/rl_sandbox/agents/dreamer/common.py @@ -0,0 +1,71 @@ +import torch +import typing as t +from dataclasses import dataclass +from jaxtyping import Float, Bool +from torch import nn +from rl_sandbox.utils.dists import DistLayer + +class View(nn.Module): + + def __init__(self, shape): + super().__init__() + self.shape = shape + + def forward(self, x): + return x.view(*self.shape) + +def Dist(val): + return DistLayer('onehot')(val) + + +@dataclass +class State: + determ: Float[torch.Tensor, 'seq batch num_slots determ'] + stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] + stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None + + @property + def combined(self): + return torch.concat([self.determ, self.stoch], dim=-1).flatten(2, 3) + + @property + def combined_slots(self): + return torch.concat([self.determ, self.stoch], dim=-1) + + @property + def stoch(self): + if self.stoch_ is None: + self.stoch_ = Dist(self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:3] + (-1,)) + return self.stoch_ + + @property + def stoch_dist(self): + return Dist(self.stoch_logits) + + @classmethod + def stack(cls, states: list['State'], dim = 0): + if states[0].stoch_ is not None: + stochs = torch.cat([state.stoch for state in states], dim=dim) + else: + stochs = None + return State(torch.cat([state.determ for state in states], dim=dim), + torch.cat([state.stoch_logits for state in states], dim=dim), + stochs) + +class Normalizer(nn.Module): + def __init__(self, momentum=0.99, scale=1.0, eps=1e-8): + super().__init__() + self.momentum = momentum + self.scale = scale + self.eps= eps + self.register_buffer('mag', torch.ones(1, dtype=torch.float32)) + self.mag.requires_grad = False + + def forward(self, x): + self.update(x) + return (x / (self.mag + self.eps))*self.scale + + def update(self, x): + self.mag = self.momentum * self.mag + (1 - self.momentum) * (x.abs().mean()).detach() + + diff --git a/rl_sandbox/agents/dreamer/rssm.py b/rl_sandbox/agents/dreamer/rssm.py new file mode 100644 index 0000000..f7de2f9 
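
State.combined in common.py above concatenates the deterministic and stochastic parts of every slot and folds the slot axis into the feature axis; that is the vector the reward, discount, actor and critic heads consume, slots_num * (rssm_dim + latent_dim * latent_classes) features in total. A quick shape check with the dimensions from the current config (8 slots, rssm_dim 40, 16x16 stochastic):

    import torch

    seq, batch, slots = 1, 16, 8
    rssm_dim, latent_dim, latent_classes = 40, 16, 16

    determ = torch.zeros(seq, batch, slots, rssm_dim)
    stoch = torch.zeros(seq, batch, slots, latent_dim * latent_classes)

    combined_slots = torch.concat([determ, stoch], dim=-1)  # per-slot features
    combined = combined_slots.flatten(2, 3)                 # slot axis folded in

    assert combined_slots.shape == (seq, batch, slots, 296)   # 40 + 16 * 16
    assert combined.shape == (seq, batch, 2368)               # 8 * 296
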
--- /dev/null +++ b/rl_sandbox/agents/dreamer/rssm.py @@ -0,0 +1,208 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from rl_sandbox.utils.schedulers import LinearScheduler +from rl_sandbox.agents.dreamer import State, View + +class GRUCell(nn.Module): + def __init__(self, input_size, hidden_size, norm=False, update_bias=-1, **kwargs): + super().__init__() + self._size = hidden_size + self._act = torch.tanh + self._norm = norm + self._update_bias = update_bias + self._layer = nn.Linear(input_size + hidden_size, 3 * hidden_size, bias=norm is not None, **kwargs) + if norm: + self._norm = nn.LayerNorm(3 * hidden_size) + + @property + def state_size(self): + return self._size + + def forward(self, x, h): + state = h + parts = self._layer(torch.concat([x, state], -1)) + if self._norm: + dtype = parts.dtype + parts = self._norm(parts.float()) + parts = parts.to(dtype=dtype) + reset, cand, update = parts.chunk(3, dim=-1) + reset = torch.sigmoid(reset) + cand = self._act(reset * cand) + update = torch.sigmoid(update + self._update_bias) + output = update * cand + (1 - update) * state + return output, output + + +class Quantize(nn.Module): + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): + super().__init__() + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + + embed = torch.randn(dim, n_embed) + self.inp_in = nn.Linear(1024, self.n_embed*self.dim) + self.inp_out = nn.Linear(self.n_embed*self.dim, 1024) + self.register_buffer("embed", embed) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", embed.clone()) + + def forward(self, inp): + # input = self.inp_in(inp).reshape(-1, 1, self.n_embed, self.dim) + input = inp.reshape(-1, 1, self.n_embed, self.dim) + inp = input + flatten = input.reshape(-1, self.dim) + dist = ( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ self.embed + + self.embed.pow(2).sum(0, keepdim=True) + ) + _, embed_ind = (-dist).max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = self.embed_code(embed_ind) + + if self.training: + embed_onehot_sum = embed_onehot.sum(0) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + + self.cluster_size.data.mul_(self.decay).add_( + embed_onehot_sum, alpha=1 - self.decay + ) + self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) + n = self.cluster_size.sum() + cluster_size = ( + (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + self.embed.data.copy_(embed_normalized) + + # quantize_out = self.inp_out(quantize.reshape(-1, self.n_embed*self.dim)) + quantize_out = quantize + diff = 0.25*(quantize_out.detach() - inp).pow(2).mean() + (quantize_out - inp.detach()).pow(2).mean() + quantize = inp + (quantize_out - inp).detach() + + return quantize, diff, embed_ind + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.transpose(0, 1)) + + +class RSSM(nn.Module): + """ + Recurrent State Space Model + + h_t <- deterministic state which is updated inside GRU + s^_t <- stohastic discrete prior state (used for KL divergence: + better predict future and encode smarter) + s_t <- stohastic discrete posterior state (latent representation of current state) + + h_1 ---> h_2 ---> h_3 ---> + \\ x_1 \\ x_2 \\ x_3 + | \\ | ^ | \\ | ^ | \\ | ^ + v MLP CNN | v MLP CNN | v MLP CNN | + \\ | | \\ | | \\ | | + Ensemble \\ | | 
Ensemble \\ | | Ensemble \\ | | + \\| | \\| | \\| | + | | | | | | | | | + v v | v v | v v | + | | | + s^_1 s_1 ---| s^_2 s_2 ---| s^_3 s_3 ---| + + """ + + def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discrete_rssm, norm_layer: nn.LayerNorm | nn.Identity, embed_size = 2*2*384): + super().__init__() + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.ensemble_num = 1 + self.hidden_size = hidden_size + self.discrete_rssm = discrete_rssm + + # Calculate deterministic state from prev stochastic, prev action and prev deterministic + self.pre_determ_recurrent = nn.Sequential( + nn.Linear(latent_dim * latent_classes + actions_num, + hidden_size), # Dreamer 'img_in' + norm_layer(hidden_size), + nn.ELU(inplace=True) + ) + self.determ_recurrent = GRUCell(input_size=hidden_size, hidden_size=hidden_size, norm=True) # Dreamer gru '_cell' + + # Calculate stochastic state from prior embed + # shared between all ensemble models + self.ensemble_prior_estimator = nn.ModuleList([ + nn.Sequential( + nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' + View((1, -1, latent_dim, self.latent_classes))) for _ in range(self.ensemble_num) + ]) + + # For observation we do not have ensemble + # FIXME: very bad magic number + # img_sz = 4 * 384 # 384x2x2 + # img_sz = 192 + img_sz = embed_size + self.stoch_net = nn.Sequential( + # nn.LayerNorm(hidden_size + img_sz, hidden_size), + nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'obs_dist' + View((1, -1, latent_dim, self.latent_classes))) + # self.determ_discretizer = MlpVAE(self.hidden_size) + self.determ_discretizer = Quantize(16, 16) + self.discretizer_scheduler = LinearScheduler(1.0, 0.0, 1_000_000) + self.determ_layer_norm = nn.LayerNorm(hidden_size) + + def estimate_stochastic_latent(self, prev_determ: torch.Tensor): + dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] + # NOTE: Maybe something smarter can be used instead of + # taking only one random between all ensembles + # NOTE: in Dreamer ensemble_num is always 1 + idx = torch.randint(0, self.ensemble_num, ()) + return dists_per_model[0] + + def predict_next(self, + prev_state: State, + action) -> State: + x = self.pre_determ_recurrent(torch.concat([prev_state.stoch, action.unsqueeze(2).repeat((1, 1, prev_state.determ.shape[2], 1))], dim=-1)) + # NOTE: x and determ are actually the same value if sequence of 1 is inserted + x, determ_prior = self.determ_recurrent(x.flatten(1, 2), prev_state.determ.flatten(1, 2)) + if self.discrete_rssm: + raise NotImplementedError("discrete rssm was not adopted for slot attention") + # determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) + # determ_post = determ_post.reshape(determ_prior.shape) + # determ_post = self.determ_layer_norm(determ_post) + # alpha = self.discretizer_scheduler.val + # determ_post = alpha * determ_prior + (1-alpha) * determ_post + else: + determ_post, diff = determ_prior, 0 + + # used for KL divergence + predicted_stoch_logits = self.estimate_stochastic_latent(x) + # Size is 1 x B x slots_num x ... 
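
The custom GRUCell above fuses the reset, candidate and update gates into a single (optionally layer-normalized) linear layer and returns the new hidden state twice, presumably to match interfaces that expect (output, state). A small usage sketch against the class as added in rssm.py:

    import torch
    from rl_sandbox.agents.dreamer.rssm import GRUCell  # module added by this patch

    cell = GRUCell(input_size=32, hidden_size=64, norm=True)
    x = torch.randn(5, 32)   # e.g. flattened slot features for a batch of 5
    h = torch.zeros(5, 64)   # previous deterministic state
    out, h_next = cell(x, h)
    assert out.shape == h_next.shape == (5, 64)
    assert torch.equal(out, h_next)   # the same tensor is returned twice
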
+ return State(determ_post.reshape(prev_state.determ.shape), predicted_stoch_logits.reshape(prev_state.stoch_logits.shape)), diff + + def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' + return State(prior.determ, self.stoch_net(torch.concat([prior.determ, embed], dim=-1)).flatten(1, 2).reshape(prior.stoch_logits.shape)) + + def forward(self, h_prev: State, embed, + action) -> tuple[State, State]: + """ + 'h' <- internal state of the world + 'z' <- latent embedding of current observation + 'a' <- action taken on prev step + Returns 'h_next' <- the next next of the world + """ + prior, diff = self.predict_next(h_prev, action) + posterior = self.update_current(prior, embed) + + return prior, posterior, diff diff --git a/rl_sandbox/agents/dreamer/vision.py b/rl_sandbox/agents/dreamer/vision.py new file mode 100644 index 0000000..b04be2d --- /dev/null +++ b/rl_sandbox/agents/dreamer/vision.py @@ -0,0 +1,89 @@ +import torch.distributions as td +from torch import nn + +class Encoder(nn.Module): + + def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[4, 4, 4]): + super().__init__() + layers = [] + + channel_step = 96 + in_channels = 3 + for i, k in enumerate(kernel_sizes): + out_channels = 2**i * channel_step + layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2)) + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) + layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')) + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) + in_channels = out_channels + # layers.append(nn.Flatten()) + self.net = nn.Sequential(*layers) + + def forward(self, X): + return self.net(X) + + +class Decoder(nn.Module): + + def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 6, 6]): + super().__init__() + layers = [] + self.channel_step = 48 + # 2**(len(kernel_sizes)-1)*channel_step + self.convin = nn.Linear(input_size, 32 * self.channel_step) + + in_channels = 32 * self.channel_step #2**(len(kernel_sizes) - 1) * self.channel_step + for i, k in enumerate(kernel_sizes): + out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step + if i == len(kernel_sizes) - 1: + out_channels = 3 + layers.append(nn.ConvTranspose2d(in_channels, 4, kernel_size=k, stride=2)) + else: + layers.append(norm_layer(1, in_channels)) + layers.append( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, + stride=2)) + layers.append(nn.ELU(inplace=True)) + in_channels = out_channels + self.net = nn.Sequential(*layers) + + def forward(self, X): + x = self.convin(X) + x = x.view(-1, 32 * self.channel_step, 1, 1) + return self.net(x) + # return td.Independent(td.Normal(self.net(x), 1.0), 3) + +class ViTDecoder(nn.Module): + + # def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3, 5, 3]): + # def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 5, 3]): + def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3, 3]): + super().__init__() + layers = [] + self.channel_step = 12 + # 2**(len(kernel_sizes)-1)*channel_step + self.convin = nn.Linear(input_size, 32 * self.channel_step) + + in_channels = 32 * self.channel_step #2**(len(kernel_sizes) - 1) * self.channel_step + for i, k in enumerate(kernel_sizes): + out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step + if i == len(kernel_sizes) - 1: + out_channels = 3 + 
layers.append(nn.ConvTranspose2d(in_channels, 384, kernel_size=k, stride=1, padding=1)) + else: + layers.append(norm_layer(1, in_channels)) + layers.append( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=2, padding=2, output_padding=1)) + layers.append(nn.ELU(inplace=True)) + in_channels = out_channels + self.net = nn.Sequential(*layers) + + def forward(self, X): + x = self.convin(X) + x = x.view(-1, 32 * self.channel_step, 1, 1) + return td.Independent(td.Normal(self.net(x), 1.0), 3) + + + diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py new file mode 100644 index 0000000..0f59e25 --- /dev/null +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -0,0 +1,249 @@ +import typing as t +import torch +import torch.distributions as td +from torch import nn +from torch.nn import functional as F +import torchvision as tv +from rl_sandbox.vision.dino import ViTFeat + +from rl_sandbox.utils.fc_nn import fc_nn_generator + +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.vision.slot_attention import SlotAttention, PositionalEmbedding + +from rl_sandbox.agents.dreamer import Dist, State, Normalizer +from rl_sandbox.agents.dreamer.rssm import RSSM +from rl_sandbox.agents.dreamer.vision import Encoder, Decoder, ViTDecoder + +class WorldModel(nn.Module): + + def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, + actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, + predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float, + slots_num: int): + super().__init__() + self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.kl_beta = kl_loss_scale + + self.rssm_dim = rssm_dim + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.slots_num = slots_num + self.state_size = slots_num * (rssm_dim + latent_dim * latent_classes) + + self.cluster_size = batch_cluster_size + self.actions_num = actions_num + # kl loss balancing (prior/posterior) + self.alpha = kl_loss_balancing + self.predict_discount = predict_discount + self.encode_vit = encode_vit + self.decode_vit = decode_vit + self.vit_l2_ratio = vit_l2_ratio + + self.n_dim = 384 + + self.recurrent_model = RSSM(latent_dim, + rssm_dim, + actions_num, + latent_classes, + discrete_rssm, + norm_layer=nn.Identity if layer_norm else nn.LayerNorm, + embed_size=self.n_dim) + if encode_vit or decode_vit: + # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) + # self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=16) + self.vit_feat_dim = self.dino_vit.feat_dim + self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches + self.dino_vit.requires_grad_(False) + + if encode_vit: + self.encoder = nn.Sequential( + self.dino_vit, + nn.Flatten(), + # fc_nn_generator(64*self.dino_vit.feat_dim, + # 64*384, + # hidden_size=400, + # num_layers=5, + # intermediate_activation=nn.ELU, + # layer_norm=layer_norm) + ) + else: + self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm) + + self.slot_attention = SlotAttention(slots_num, self.n_dim, 5) + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) + # self.positional_augmenter_dec = 
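
The (6, 6) grid of the positional embedding and n_dim = 384 follow from the enlarged encoder: three stride-2 convolutions with kernel 4 (each now followed by a stride-1 'same'-padded conv) shrink a 64x64 observation to 31 -> 14 -> 6 while the channels grow 96 -> 192 -> 384. A quick check, assuming the usual 64x64 inputs:

    import torch
    from torch import nn
    from rl_sandbox.agents.dreamer.vision import Encoder  # module added by this patch

    enc = Encoder(norm_layer=nn.GroupNorm)
    feat = enc(torch.zeros(1, 3, 64, 64))
    assert feat.shape == (1, 384, 6, 6)   # 2**2 * 96 channels on a 6x6 grid

    # spatial sizes: floor((64 - 4) / 2) + 1 = 31, then 14, then 6;
    # the kernel-3 'same' convolutions leave the spatial size unchanged
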
PositionalEmbedding(self.n_dim, (8, 8)) + + self.slot_mlp = nn.Sequential( + nn.Linear(self.n_dim, self.n_dim), + nn.ReLU(inplace=True), + nn.Linear(self.n_dim, self.n_dim) + ) + + + if decode_vit: + self.dino_predictor = ViTDecoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.Identity if layer_norm else nn.GroupNorm) + # self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, + # 64*self.dino_vit.feat_dim, + # hidden_size=2048, + # num_layers=5, + # intermediate_activation=nn.ELU, + # layer_norm=layer_norm, + # final_activation=DistLayer('mse')) + self.image_predictor = Decoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.Identity if layer_norm else nn.GroupNorm) + + self.reward_predictor = fc_nn_generator(slots_num*(rssm_dim + latent_dim * latent_classes), + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.discount_predictor = fc_nn_generator(slots_num*(rssm_dim + latent_dim * latent_classes), + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('binary')) + self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): + device = next(self.parameters()).device + return State(torch.zeros(seq_size, batch_size, self.slots_num, self.rssm_dim, device=device), + torch.zeros(seq_size, batch_size, self.slots_num, self.latent_classes, self.latent_dim, device=device), + torch.zeros(seq_size, batch_size, self.slots_num, self.latent_classes * self.latent_dim, device=device)) + + def predict_next(self, prev_state: State, action): + prior, _ = self.recurrent_model.predict_next(prev_state, action) + + reward = self.reward_predictor(prior.combined).mode + if self.predict_discount: + discount_factors = self.discount_predictor(prior.combined).sample() + else: + discount_factors = torch.ones_like(reward) + return prior, reward, discount_factors + + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State], prev_slots: t.Optional[torch.Tensor]) -> t.Tuple[State, torch.Tensor]: + if state is None: + state = self.get_initial_state() + embed = self.encoder(obs.unsqueeze(0)) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + + pre_slot_features_t = self.slot_mlp(embed_with_pos_enc.permute(0, 2, 3, 1).reshape(1, -1, self.n_dim)) + + slots_t = self.slot_attention(pre_slot_features_t, prev_slots) + + _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), action) + return posterior, slots_t + + def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + discount: torch.Tensor, first: torch.Tensor): + b, _, h, w = obs.shape # s <- BxHxWx3 + + embed = self.encoder(obs) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + # embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) + + pre_slot_features = self.slot_mlp(embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) + pre_slot_features_c = pre_slot_features.reshape(b // self.cluster_size, self.cluster_size, -1, self.n_dim) + + a_c = a.reshape(-1, self.cluster_size, self.actions_num) + r_c = r.reshape(-1, self.cluster_size, 1) + d_c = discount.reshape(-1, self.cluster_size, 1) + first_c = first.reshape(-1, self.cluster_size, 1) + + losses = {} + metrics = {} + + def KL(dist1, dist2): + KL_ = torch.distributions.kl_divergence + kl_lhs = 
KL_(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), td.OneHotCategoricalStraightThrough(logits=dist1)).mean() + kl_rhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2), td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() + kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) + kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) + return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + + priors = [] + posteriors = [] + + if self.decode_vit: + inp = obs + if not self.encode_vit: + ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(224, antialias=True)]) + # ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + # (0.229, 0.224, 0.225)) + inp = ToTensor(obs + 0.5) + d_features = self.dino_vit(inp) + + prev_state = self.get_initial_state(b // self.cluster_size) + prev_slots = None + for t in range(self.cluster_size): + # s_t <- 1xB^xHxWx3 + pre_slot_feature_t, a_t, first_t = pre_slot_features_c[:, t], a_c[:, t].unsqueeze(0), first_c[:, t].unsqueeze(0) + a_t = a_t * (1 - first_t) + + slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) + # prev_slots = None + + prior, posterior, diff = self.recurrent_model.forward(prev_state, slots_t.unsqueeze(0), a_t) + prev_state = posterior + + priors.append(prior) + posteriors.append(posterior) + + # losses['loss_determ_recons'] += diff + + posterior = State.stack(posteriors) + prior = State.stack(priors) + + r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) + f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) + + losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) + + if not self.decode_vit: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + + losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() + else: + raise NotImplementedError("") + # if self.vit_l2_ratio != 1.0: + # x_r = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) + # img_rec = -x_r.log_prob(obs).float().mean() + # else: + # img_rec = 0 + # x_r_detached = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1).detach()) + # losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() + # d_pred = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) + # losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 14, 14)).float().mean()/4 + + # (1-self.vit_l2_ratio) * img_rec) + + prior_logits = prior.stoch_logits + posterior_logits = posterior.stoch_logits + losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() + losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() + losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) + + metrics['reward_mean'] = r.mean() + metrics['reward_std'] = r.std() + metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() + metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() + metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() + + losses['loss_wm'] = (losses['loss_reconstruction'] + + losses['loss_reward_pred'] + + losses['loss_kl_reg'] + + losses['loss_discount_pred']) + + return losses, posterior, metrics + + diff 
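
The KL helper above is DreamerV2's KL balancing: the divergence between posterior and prior is computed twice, once with the posterior detached (training the prior) and once with the prior detached (regularizing the posterior), each clipped at kl_free_nats, mixed with alpha = kl_loss_balancing and scaled by kl_loss_scale, the quantity being swept in the configs. A standalone sketch of that computation with the current config values as defaults:

    import torch
    import torch.distributions as td

    def balanced_kl(prior_logits, posterior_logits, alpha=0.8, beta=32.0, free_nats=0.0):
        D = td.OneHotCategoricalStraightThrough
        kl = td.kl_divergence
        free = torch.tensor(free_nats)
        # gradient flows only into the prior here ...
        lhs = kl(D(logits=posterior_logits.detach()), D(logits=prior_logits)).mean()
        # ... and only into the posterior here
        rhs = kl(D(logits=posterior_logits), D(logits=prior_logits.detach())).mean()
        return beta * (alpha * torch.maximum(lhs, free) + (1 - alpha) * torch.maximum(rhs, free))

    # logits shaped like State.stoch_logits: seq x batch x slots x 16 x 16
    prior = torch.randn(1, 16, 8, 16, 16, requires_grad=True)
    posterior = torch.randn(1, 16, 8, 16, 16, requires_grad=True)
    print(balanced_kl(prior, posterior))
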
--git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 8b71c52..943db4a 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -1,660 +1,23 @@ import typing as t -from collections import defaultdict -from dataclasses import dataclass from pathlib import Path -from functools import partial import matplotlib.pyplot as plt import numpy as np import torch -import torch.distributions as td from torch import nn from torch.nn import functional as F import torchvision as tv -from jaxtyping import Float, Bool -from rl_sandbox.vision.dino import ViTFeat from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.fc_nn import fc_nn_generator from rl_sandbox.utils.replay_buffer import (Action, Actions, Observation, Observations, Rewards, TerminationFlags, IsFirstFlags) -from rl_sandbox.utils.schedulers import LinearScheduler -from rl_sandbox.utils.dists import DistLayer -from rl_sandbox.vision.slot_attention import SlotAttention, PositionalEmbedding +from rl_sandbox.utils.optimizer import Optimizer -class View(nn.Module): - - def __init__(self, shape): - super().__init__() - self.shape = shape - - def forward(self, x): - return x.view(*self.shape) - -def Dist(val): - return DistLayer('onehot')(val) - -@dataclass -class State: - determ: Float[torch.Tensor, 'seq batch num_slots determ'] - stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] - stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None - - @property - def combined(self): - return torch.concat([self.determ, self.stoch], dim=-1).flatten(2, 3) - - @property - def combined_slots(self): - return torch.concat([self.determ, self.stoch], dim=-1) - - @property - def stoch(self): - if self.stoch_ is None: - self.stoch_ = Dist(self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:3] + (-1,)) - return self.stoch_ - - @property - def stoch_dist(self): - return Dist(self.stoch_logits) - - @classmethod - def stack(cls, states: list['State'], dim = 0): - if states[0].stoch_ is not None: - stochs = torch.cat([state.stoch for state in states], dim=dim) - else: - stochs = None - return State(torch.cat([state.determ for state in states], dim=dim), - torch.cat([state.stoch_logits for state in states], dim=dim), - stochs) - -class GRUCell(nn.Module): - def __init__(self, input_size, hidden_size, norm=False, update_bias=-1, **kwargs): - super().__init__() - self._size = hidden_size - self._act = torch.tanh - self._norm = norm - self._update_bias = update_bias - self._layer = nn.Linear(input_size + hidden_size, 3 * hidden_size, bias=norm is not None, **kwargs) - if norm: - self._norm = nn.LayerNorm(3 * hidden_size) - - @property - def state_size(self): - return self._size - - def forward(self, x, h): - state = h - parts = self._layer(torch.concat([x, state], -1)) - if self._norm: - dtype = parts.dtype - parts = self._norm(parts.float()) - parts = parts.to(dtype=dtype) - reset, cand, update = parts.chunk(3, dim=-1) - reset = torch.sigmoid(reset) - cand = self._act(reset * cand) - update = torch.sigmoid(update + self._update_bias) - output = update * cand + (1 - update) * state - return output, output - - -class Quantize(nn.Module): - def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): - super().__init__() - - self.dim = dim - self.n_embed = n_embed - self.decay = decay - self.eps = eps - - embed = torch.randn(dim, n_embed) - self.inp_in = nn.Linear(1024, self.n_embed*self.dim) - self.inp_out = nn.Linear(self.n_embed*self.dim, 1024) - 
self.register_buffer("embed", embed) - self.register_buffer("cluster_size", torch.zeros(n_embed)) - self.register_buffer("embed_avg", embed.clone()) - - def forward(self, inp): - # input = self.inp_in(inp).reshape(-1, 1, self.n_embed, self.dim) - input = inp.reshape(-1, 1, self.n_embed, self.dim) - inp = input - flatten = input.reshape(-1, self.dim) - dist = ( - flatten.pow(2).sum(1, keepdim=True) - - 2 * flatten @ self.embed - + self.embed.pow(2).sum(0, keepdim=True) - ) - _, embed_ind = (-dist).max(1) - embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) - embed_ind = embed_ind.view(*input.shape[:-1]) - quantize = self.embed_code(embed_ind) - - if self.training: - embed_onehot_sum = embed_onehot.sum(0) - embed_sum = flatten.transpose(0, 1) @ embed_onehot - - self.cluster_size.data.mul_(self.decay).add_( - embed_onehot_sum, alpha=1 - self.decay - ) - self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) - n = self.cluster_size.sum() - cluster_size = ( - (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n - ) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) - self.embed.data.copy_(embed_normalized) - - # quantize_out = self.inp_out(quantize.reshape(-1, self.n_embed*self.dim)) - quantize_out = quantize - diff = 0.25*(quantize_out.detach() - inp).pow(2).mean() + (quantize_out - inp.detach()).pow(2).mean() - quantize = inp + (quantize_out - inp).detach() - - return quantize, diff, embed_ind - - def embed_code(self, embed_id): - return F.embedding(embed_id, self.embed.transpose(0, 1)) - - -class RSSM(nn.Module): - """ - Recurrent State Space Model - - h_t <- deterministic state which is updated inside GRU - s^_t <- stohastic discrete prior state (used for KL divergence: - better predict future and encode smarter) - s_t <- stohastic discrete posterior state (latent representation of current state) - - h_1 ---> h_2 ---> h_3 ---> - \\ x_1 \\ x_2 \\ x_3 - | \\ | ^ | \\ | ^ | \\ | ^ - v MLP CNN | v MLP CNN | v MLP CNN | - \\ | | \\ | | \\ | | - Ensemble \\ | | Ensemble \\ | | Ensemble \\ | | - \\| | \\| | \\| | - | | | | | | | | | - v v | v v | v v | - | | | - s^_1 s_1 ---| s^_2 s_2 ---| s^_3 s_3 ---| - - """ - - def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discrete_rssm, norm_layer: nn.LayerNorm | nn.Identity, embed_size = 2*2*384): - super().__init__() - self.latent_dim = latent_dim - self.latent_classes = latent_classes - self.ensemble_num = 1 - self.hidden_size = hidden_size - self.discrete_rssm = discrete_rssm - - # Calculate deterministic state from prev stochastic, prev action and prev deterministic - self.pre_determ_recurrent = nn.Sequential( - nn.Linear(latent_dim * latent_classes + actions_num, - hidden_size), # Dreamer 'img_in' - norm_layer(hidden_size), - nn.ELU(inplace=True) - ) - self.determ_recurrent = GRUCell(input_size=hidden_size, hidden_size=hidden_size, norm=True) # Dreamer gru '_cell' - - # Calculate stochastic state from prior embed - # shared between all ensemble models - self.ensemble_prior_estimator = nn.ModuleList([ - nn.Sequential( - nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' - norm_layer(hidden_size), - nn.ELU(inplace=True), - nn.Linear(hidden_size, - latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' - View((1, -1, latent_dim, self.latent_classes))) for _ in range(self.ensemble_num) - ]) - - # For observation we do not have ensemble - # FIXME: very bad magic number - # img_sz = 4 * 384 # 384x2x2 - # img_sz = 192 - img_sz = embed_size - 
self.stoch_net = nn.Sequential( - # nn.LayerNorm(hidden_size + img_sz, hidden_size), - nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' - norm_layer(hidden_size), - nn.ELU(inplace=True), - nn.Linear(hidden_size, - latent_dim * self.latent_classes), # Dreamer 'obs_dist' - View((1, -1, latent_dim, self.latent_classes))) - # self.determ_discretizer = MlpVAE(self.hidden_size) - self.determ_discretizer = Quantize(16, 16) - self.discretizer_scheduler = LinearScheduler(1.0, 0.0, 1_000_000) - self.determ_layer_norm = nn.LayerNorm(hidden_size) - - def estimate_stochastic_latent(self, prev_determ: torch.Tensor): - dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] - # NOTE: Maybe something smarter can be used instead of - # taking only one random between all ensembles - # NOTE: in Dreamer ensemble_num is always 1 - idx = torch.randint(0, self.ensemble_num, ()) - return dists_per_model[0] - - def predict_next(self, - prev_state: State, - action) -> State: - x = self.pre_determ_recurrent(torch.concat([prev_state.stoch, action.unsqueeze(2).repeat((1, 1, prev_state.determ.shape[2], 1))], dim=-1)) - # NOTE: x and determ are actually the same value if sequence of 1 is inserted - x, determ_prior = self.determ_recurrent(x.flatten(1, 2), prev_state.determ.flatten(1, 2)) - if self.discrete_rssm: - raise NotImplementedError("discrete rssm was not adopted for slot attention") - # determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) - # determ_post = determ_post.reshape(determ_prior.shape) - # determ_post = self.determ_layer_norm(determ_post) - # alpha = self.discretizer_scheduler.val - # determ_post = alpha * determ_prior + (1-alpha) * determ_post - else: - determ_post, diff = determ_prior, 0 - - # used for KL divergence - predicted_stoch_logits = self.estimate_stochastic_latent(x) - # Size is 1 x B x slots_num x ... 
- return State(determ_post.reshape(prev_state.determ.shape), predicted_stoch_logits.reshape(prev_state.stoch_logits.shape)), diff - - def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' - return State(prior.determ, self.stoch_net(torch.concat([prior.determ, embed], dim=-1)).flatten(1, 2).reshape(prior.stoch_logits.shape)) - - def forward(self, h_prev: State, embed, - action) -> tuple[State, State]: - """ - 'h' <- internal state of the world - 'z' <- latent embedding of current observation - 'a' <- action taken on prev step - Returns 'h_next' <- the next next of the world - """ - prior, diff = self.predict_next(h_prev, action) - posterior = self.update_current(prior, embed) - - return prior, posterior, diff - - -class Encoder(nn.Module): - - def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[4, 4, 4]): - super().__init__() - layers = [] - - channel_step = 96 - in_channels = 3 - for i, k in enumerate(kernel_sizes): - out_channels = 2**i * channel_step - layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2)) - layers.append(norm_layer(1, out_channels)) - layers.append(nn.ELU(inplace=True)) - layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')) - layers.append(norm_layer(1, out_channels)) - layers.append(nn.ELU(inplace=True)) - in_channels = out_channels - # layers.append(nn.Flatten()) - self.net = nn.Sequential(*layers) - - def forward(self, X): - return self.net(X) - - -class Decoder(nn.Module): - - def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 6, 6]): - super().__init__() - layers = [] - self.channel_step = 48 - # 2**(len(kernel_sizes)-1)*channel_step - self.convin = nn.Linear(input_size, 32 * self.channel_step) - - in_channels = 32 * self.channel_step #2**(len(kernel_sizes) - 1) * self.channel_step - for i, k in enumerate(kernel_sizes): - out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step - if i == len(kernel_sizes) - 1: - out_channels = 3 - layers.append(nn.ConvTranspose2d(in_channels, 4, kernel_size=k, stride=2)) - else: - layers.append(norm_layer(1, in_channels)) - layers.append( - nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, - stride=2)) - layers.append(nn.ELU(inplace=True)) - in_channels = out_channels - self.net = nn.Sequential(*layers) - - def forward(self, X): - x = self.convin(X) - x = x.view(-1, 32 * self.channel_step, 1, 1) - return self.net(x) - # return td.Independent(td.Normal(self.net(x), 1.0), 3) - -class ViTDecoder(nn.Module): - - # def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3, 5, 3]): - # def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 5, 3]): - def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3, 3]): - super().__init__() - layers = [] - self.channel_step = 12 - # 2**(len(kernel_sizes)-1)*channel_step - self.convin = nn.Linear(input_size, 32 * self.channel_step) - - in_channels = 32 * self.channel_step #2**(len(kernel_sizes) - 1) * self.channel_step - for i, k in enumerate(kernel_sizes): - out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step - if i == len(kernel_sizes) - 1: - out_channels = 3 - layers.append(nn.ConvTranspose2d(in_channels, 384, kernel_size=k, stride=1, padding=1)) - else: - layers.append(norm_layer(1, in_channels)) - layers.append( - nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=2, padding=2, 
output_padding=1)) - layers.append(nn.ELU(inplace=True)) - in_channels = out_channels - self.net = nn.Sequential(*layers) - - def forward(self, X): - x = self.convin(X) - x = x.view(-1, 32 * self.channel_step, 1, 1) - return td.Independent(td.Normal(self.net(x), 1.0), 3) - - - -class Normalizer(nn.Module): - def __init__(self, momentum=0.99, scale=1.0, eps=1e-8): - super().__init__() - self.momentum = momentum - self.scale = scale - self.eps= eps - self.register_buffer('mag', torch.ones(1, dtype=torch.float32)) - self.mag.requires_grad = False - - def forward(self, x): - self.update(x) - return (x / (self.mag + self.eps))*self.scale - - def update(self, x): - self.mag = self.momentum * self.mag + (1 - self.momentum) * (x.abs().mean()).detach() - - -class WorldModel(nn.Module): - - def __init__(self, img_size, batch_cluster_size, latent_dim, latent_classes, rssm_dim, - actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, - predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float, - slots_num: int): - super().__init__() - self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) - self.kl_beta = kl_loss_scale - self.rssm_dim = rssm_dim - self.latent_dim = latent_dim - self.latent_classes = latent_classes - self.slots_num = slots_num - self.cluster_size = batch_cluster_size - self.actions_num = actions_num - # kl loss balancing (prior/posterior) - self.alpha = kl_loss_balancing - self.predict_discount = predict_discount - self.encode_vit = encode_vit - self.decode_vit = decode_vit - self.vit_l2_ratio = vit_l2_ratio - - self.n_dim = 384 - - self.recurrent_model = RSSM(latent_dim, - rssm_dim, - actions_num, - latent_classes, - discrete_rssm, - norm_layer=nn.Identity if layer_norm else nn.LayerNorm, - embed_size=self.n_dim) - if encode_vit or decode_vit: - # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) - # self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) - self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=16) - self.vit_feat_dim = self.dino_vit.feat_dim - self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches - self.dino_vit.requires_grad_(False) - - if encode_vit: - self.encoder = nn.Sequential( - self.dino_vit, - nn.Flatten(), - # fc_nn_generator(64*self.dino_vit.feat_dim, - # 64*384, - # hidden_size=400, - # num_layers=5, - # intermediate_activation=nn.ELU, - # layer_norm=layer_norm) - ) - else: - self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm) - - self.slot_attention = SlotAttention(slots_num, self.n_dim, 5) - self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) - # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) - - self.slot_mlp = nn.Sequential( - nn.Linear(self.n_dim, self.n_dim), - nn.ReLU(inplace=True), - nn.Linear(self.n_dim, self.n_dim) - ) - - - if decode_vit: - self.dino_predictor = ViTDecoder(rssm_dim + latent_dim * latent_classes, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm) - # self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, - # 64*self.dino_vit.feat_dim, - # hidden_size=2048, - # num_layers=5, - # intermediate_activation=nn.ELU, - # layer_norm=layer_norm, - # final_activation=DistLayer('mse')) - self.image_predictor = Decoder(rssm_dim + 
latent_dim * latent_classes, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm) - - self.reward_predictor = fc_nn_generator(slots_num*(rssm_dim + latent_dim * latent_classes), - 1, - hidden_size=400, - num_layers=5, - intermediate_activation=nn.ELU, - layer_norm=layer_norm, - final_activation=DistLayer('mse')) - self.discount_predictor = fc_nn_generator(slots_num*(rssm_dim + latent_dim * latent_classes), - 1, - hidden_size=400, - num_layers=5, - intermediate_activation=nn.ELU, - layer_norm=layer_norm, - final_activation=DistLayer('binary')) - self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) - - def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): - device = next(self.parameters()).device - return State(torch.zeros(seq_size, batch_size, self.slots_num, self.rssm_dim, device=device), - torch.zeros(seq_size, batch_size, self.slots_num, self.latent_classes, self.latent_dim, device=device), - torch.zeros(seq_size, batch_size, self.slots_num, self.latent_classes * self.latent_dim, device=device)) - - def predict_next(self, prev_state: State, action): - prior, _ = self.recurrent_model.predict_next(prev_state, action) - - reward = self.reward_predictor(prior.combined).mode - if self.predict_discount: - discount_factors = self.discount_predictor(prior.combined).sample() - else: - discount_factors = torch.ones_like(reward) - return prior, reward, discount_factors - - def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State], prev_slots: t.Optional[torch.Tensor]) -> t.Tuple[State, torch.Tensor]: - if state is None: - state = self.get_initial_state() - embed = self.encoder(obs.unsqueeze(0)) - embed_with_pos_enc = self.positional_augmenter_inp(embed) - - pre_slot_features_t = self.slot_mlp(embed_with_pos_enc.permute(0, 2, 3, 1).reshape(1, -1, self.n_dim)) - - slots_t = self.slot_attention(pre_slot_features_t, prev_slots) - - _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), action) - return posterior, slots_t - - def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, - discount: torch.Tensor, first: torch.Tensor): - b, _, h, w = obs.shape # s <- BxHxWx3 - - embed = self.encoder(obs) - embed_with_pos_enc = self.positional_augmenter_inp(embed) - # embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) - - pre_slot_features = self.slot_mlp(embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) - pre_slot_features_c = pre_slot_features.reshape(b // self.cluster_size, self.cluster_size, -1, self.n_dim) - - a_c = a.reshape(-1, self.cluster_size, self.actions_num) - r_c = r.reshape(-1, self.cluster_size, 1) - d_c = discount.reshape(-1, self.cluster_size, 1) - first_c = first.reshape(-1, self.cluster_size, 1) - - losses = {} - metrics = {} - - def KL(dist1, dist2): - KL_ = torch.distributions.kl_divergence - kl_lhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), td.OneHotCategoricalStraightThrough(logits=dist1)).mean() - kl_rhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2), td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() - kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) - kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) - return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) - - priors = [] - posteriors = [] - - if self.decode_vit: - inp = obs - if not self.encode_vit: - ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)), - tv.transforms.Resize(224, 
antialias=True)]) - # ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - # (0.229, 0.224, 0.225)) - inp = ToTensor(obs + 0.5) - d_features = self.dino_vit(inp) - - prev_state = self.get_initial_state(b // self.cluster_size) - prev_slots = None - for t in range(self.cluster_size): - # s_t <- 1xB^xHxWx3 - pre_slot_feature_t, a_t, first_t = pre_slot_features_c[:, t], a_c[:, t].unsqueeze(0), first_c[:, t].unsqueeze(0) - a_t = a_t * (1 - first_t) - - slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) - # prev_slots = None - - prior, posterior, diff = self.recurrent_model.forward(prev_state, slots_t.unsqueeze(0), a_t) - prev_state = posterior - - priors.append(prior) - posteriors.append(posterior) - - # losses['loss_determ_recons'] += diff - - posterior = State.stack(posteriors) - prior = State.stack(priors) - - r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) - f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) - - losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) - - if not self.decode_vit: - decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) - decoded_imgs = decoded_imgs * img_mask - x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) - - losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() - else: - raise NotImplementedError("") - # if self.vit_l2_ratio != 1.0: - # x_r = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) - # img_rec = -x_r.log_prob(obs).float().mean() - # else: - # img_rec = 0 - # x_r_detached = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1).detach()) - # losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() - # d_pred = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) - # losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 14, 14)).float().mean()/4 + - # (1-self.vit_l2_ratio) * img_rec) - - prior_logits = prior.stoch_logits - posterior_logits = posterior.stoch_logits - losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() - losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() - losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) - - metrics['reward_mean'] = r.mean() - metrics['reward_std'] = r.std() - metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() - metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() - metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() - - return losses, posterior, metrics - - -class ImaginativeCritic(nn.Module): - - def __init__(self, discount_factor: float, update_interval: int, - soft_update_fraction: float, value_target_lambda: float, latent_dim: int, layer_norm: bool): - super().__init__() - self.gamma = discount_factor - self.critic_update_interval = update_interval - self.lambda_ = value_target_lambda - self.critic_soft_update_fraction = soft_update_fraction - self._update_num = 0 - - self.critic = fc_nn_generator(latent_dim, - 1, - 400, - 5, - intermediate_activation=nn.ELU, - layer_norm=layer_norm, - final_activation=DistLayer('mse')) - self.target_critic = fc_nn_generator(latent_dim, - 1, - 400, - 5, - intermediate_activation=nn.ELU, - layer_norm=layer_norm, - final_activation=DistLayer('mse')) - self.target_critic.requires_grad_(False) - - def update_target(self): - if 
self._update_num == 0: - self.target_critic.load_state_dict(self.critic.state_dict()) - # for target_param, local_param in zip(self.target_critic.parameters(), - # self.critic.parameters()): - # mix = self.critic_soft_update_fraction - # target_param.data.copy_(mix * local_param.data + - # (1 - mix) * target_param.data) - self._update_num = (self._update_num + 1) % self.critic_update_interval - - def estimate_value(self, z) -> td.Distribution: - return self.critic(z) - - def _lambda_return(self, vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor): - # Formula is actually slightly different than in paper - # https://github.com/danijar/dreamerv2/issues/25 - v_lambdas = [vs[-1]] - for i in range(rs.shape[0] - 1, -1, -1): - v_lambda = rs[i] + ds[i] * ( - (1 - self.lambda_) * vs[i+1] + - self.lambda_ * v_lambdas[-1]) - v_lambdas.append(v_lambda) - - # FIXME: it copies array, so it is quite slow - return torch.stack(v_lambdas).flip(dims=(0,))[:-1] - - def lambda_return(self, zs, rs, ds): - vs = self.target_critic(zs).mode - return self._lambda_return(vs, rs, ds) +from rl_sandbox.agents.dreamer import State +from rl_sandbox.agents.dreamer.world_model import WorldModel +from rl_sandbox.agents.dreamer.ac import ImaginativeCritic, ImaginativeActor class DreamerV2(RlAgent): @@ -663,93 +26,36 @@ def __init__( self, obs_space_num: list[int], # NOTE: encoder/decoder will work only with 64x64 currently actions_num: int, + world_model: t.Any, + actor: t.Any, + critic: t.Any, action_type: str, - batch_cluster_size: int, - latent_dim: int, - latent_classes: int, - rssm_dim: int, - discount_factor: float, - kl_loss_scale: float, - kl_loss_balancing: float, - kl_loss_free_nats: float, imagination_horizon: int, - critic_update_interval: int, - actor_reinforce_fraction: float, - actor_entropy_scale: float, - critic_soft_update_fraction: float, - critic_value_target_lambda: float, - world_model_lr: float, - world_model_predict_discount: bool, - actor_lr: float, - critic_lr: float, - discrete_rssm: bool, + wm_optim: t.Any, + actor_optim: t.Any, + critic_optim: t.Any, layer_norm: bool, - encode_vit: bool, - decode_vit: bool, - vit_l2_ratio: float, - slots_num: int, + batch_cluster_size: int, device_type: str = 'cpu', logger = None): self.logger = logger self.device = device_type self.imagination_horizon = imagination_horizon - self.cluster_size = batch_cluster_size self.actions_num = actions_num - self.rho = actor_reinforce_fraction - self.eta = actor_entropy_scale self.is_discrete = (action_type != 'continuous') - if self.rho is None: - self.rho = self.is_discrete - - self.world_model = WorldModel(obs_space_num[0], batch_cluster_size, latent_dim, latent_classes, - rssm_dim, actions_num, kl_loss_scale, - kl_loss_balancing, kl_loss_free_nats, - discrete_rssm, - world_model_predict_discount, layer_norm, - encode_vit, decode_vit, vit_l2_ratio, slots_num).to(device_type) - - self.actor = fc_nn_generator(slots_num*(rssm_dim + latent_dim * latent_classes), - actions_num if self.is_discrete else actions_num * 2, - 400, - 5, - layer_norm=layer_norm, - intermediate_activation=nn.ELU, - final_activation=DistLayer('onehot' if self.is_discrete else 'normal_trunc')).to(device_type) - self.critic = ImaginativeCritic(discount_factor, critic_update_interval, - critic_soft_update_fraction, - critic_value_target_lambda, - slots_num*(rssm_dim + latent_dim * latent_classes), - layer_norm=layer_norm).to(device_type) + self.world_model: WorldModel = world_model(actions_num=actions_num).to(device_type) + self.actor: ImaginativeActor = 
actor(latent_dim=self.world_model.state_size, + actions_num=actions_num, + is_discrete=self.is_discrete).to(device_type) + self.critic: ImaginativeCritic = critic(latent_dim=self.world_model.state_size).to(device_type) - self.scaler = torch.cuda.amp.GradScaler() - self.image_predictor_optimizer = torch.optim.AdamW(self.world_model.image_predictor.parameters(), - lr=world_model_lr, - eps=1e-5, - weight_decay=1e-6) + self.world_model_optimizer = wm_optim(model=self.world_model) + self.image_predictor_optimizer = wm_optim(model=self.world_model.image_predictor) + self.actor_optimizer = actor_optim(model=self.actor) + self.critic_optimizer = critic_optim(model=self.critic) - self.world_model_optimizer = torch.optim.AdamW(self.world_model.parameters(), - lr=world_model_lr, - eps=1e-5, - weight_decay=1e-6) - - warmup_steps = 1e3 - decay_rate = 0.5 - decay_steps = 5e5 - lr_warmup_scheduler = torch.optim.lr_scheduler.LinearLR(self.world_model_optimizer, start_factor=1/warmup_steps, total_iters=int(warmup_steps)) - lr_decay_scheduler = torch.optim.lr_scheduler.LambdaLR(self.world_model_optimizer, lambda epoch: decay_rate**(epoch/decay_steps)) - # lr_decay_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=decay_rate**(1/decay_steps)) - self.lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler([lr_warmup_scheduler, lr_decay_scheduler]) - - self.actor_optimizer = torch.optim.AdamW(self.actor.parameters(), - lr=actor_lr, - eps=1e-5, - weight_decay=1e-6) - self.critic_optimizer = torch.optim.AdamW(self.critic.parameters(), - lr=critic_lr, - eps=1e-5, - weight_decay=1e-6) self.reset() def imagine_trajectory( @@ -943,36 +249,14 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # take some latent embeddings as initial with torch.cuda.amp.autocast(enabled=False): - losses, discovered_states, wm_metrics = self.world_model.calculate_loss(obs, a, r, discount_factors, first_flags) + losses_wm, discovered_states, metrics_wm = self.world_model.calculate_loss(obs, a, r, discount_factors, first_flags) self.world_model.recurrent_model.discretizer_scheduler.step() - # NOTE: 'aten::nonzero' inside KL divergence is not currently supported on M1 Pro MPS device - image_predictor_loss = losses['loss_reconstruction_img'] - world_model_loss = (losses['loss_reconstruction'] + - losses['loss_reward_pred'] + - losses['loss_kl_reg'] + - losses['loss_discount_pred']) - # for l in losses.values(): - # world_model_loss += l if self.world_model.decode_vit and self.world_model.vit_l2_ratio == 1.0: - self.image_predictor_optimizer.zero_grad(set_to_none=True) - self.scaler.scale(image_predictor_loss).backward() - self.scaler.unscale_(self.image_predictor_optimizer) - nn.utils.clip_grad_norm_(self.world_model.image_predictor.parameters(), 100) - self.scaler.step(self.image_predictor_optimizer) + self.image_predictor_optimizer.step(losses_wm['loss_reconstruction_img']) - self.world_model_optimizer.zero_grad(set_to_none=True) - self.scaler.scale(world_model_loss).backward() - # FIXME: clip gradient should be parametrized - self.scaler.unscale_(self.world_model_optimizer) - # for tag, value in self.world_model.named_parameters(): - # wm_metrics[f"grad/{tag.replace('.', '/')}"] = value.detach() - nn.utils.clip_grad_norm_(self.world_model.parameters(), 100) - self.scaler.step(self.world_model_optimizer) - self.lr_scheduler.step() - - metrics = wm_metrics + metrics_wm |= self.world_model_optimizer.step(losses_wm['loss_wm']) with torch.cuda.amp.autocast(enabled=False): losses_ac = {} @@ 
-993,61 +277,28 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # Ignore all factors after first is_finished state discount_factors = torch.cumprod(discount_factors, dim=0) - predicted_vs_dist = self.critic.estimate_value(zs[:-1].detach()) - losses_ac['loss_critic'] = -(predicted_vs_dist.log_prob(vs.detach()).unsqueeze(2)*discount_factors[:-1]).mean() - - metrics['critic/avg_target_value'] = self.critic.target_critic(zs[1:]).mode.mean() - metrics['critic/avg_lambda_value'] = vs.mean() - metrics['critic/avg_predicted_value'] = predicted_vs_dist.mode.mean() + losses_c, metrics_c = self.critic.calculate_loss(zs[:-1], vs, discount_factors[:-1]) # last action should be ignored as it is not used to predict next state, thus no feedback # first value should be ignored as it is comes from replay buffer - action_dists = self.actor(zs[:-2].detach()) - baseline = self.critic.target_critic(zs[:-2]).mode - advantage = (vs[1:] - baseline).detach() - losses_ac['loss_actor_reinforce'] = -(self.rho * action_dists.log_prob(actions[1:-1].detach()).unsqueeze(2) * discount_factors[:-2] * advantage).mean() - losses_ac['loss_actor_dynamics_backprop'] = -((1 - self.rho) * (vs[1:]*discount_factors[:-2])).mean() - - def calculate_entropy(dist): - return dist.entropy().unsqueeze(2) - # return dist.base_dist.base_dist.entropy().unsqueeze(2) - - losses_ac['loss_actor_entropy'] = -(self.eta * calculate_entropy(action_dists)*discount_factors[:-2]).mean() - losses_ac['loss_actor'] = losses_ac['loss_actor_reinforce'] + losses_ac['loss_actor_dynamics_backprop'] + losses_ac['loss_actor_entropy'] - - # mean and std are estimated statistically as tanh transformation is used - sample = action_dists.rsample((128,)) - act_avg = sample.mean(0) - metrics['actor/avg_val'] = act_avg.mean() - # metrics['actor/mode_val'] = action_dists.mode.mean() - metrics['actor/mean_val'] = action_dists.mean.mean() - metrics['actor/avg_sd'] = (((sample - act_avg)**2).mean(0).sqrt()).mean() - metrics['actor/min_val'] = sample.min() - metrics['actor/max_val'] = sample.max() - - self.actor_optimizer.zero_grad(set_to_none=True) - self.critic_optimizer.zero_grad(set_to_none=True) - - self.scaler.scale(losses_ac['loss_critic']).backward() - self.scaler.scale(losses_ac['loss_actor']).backward() - - self.scaler.unscale_(self.actor_optimizer) - self.scaler.unscale_(self.critic_optimizer) - nn.utils.clip_grad_norm_(self.actor.parameters(), 100) - nn.utils.clip_grad_norm_(self.critic.parameters(), 100) - - self.scaler.step(self.actor_optimizer) - self.scaler.step(self.critic_optimizer) + losses_a, metrics_a = self.actor.calculate_loss(zs[:-2], + vs[1:], + self.critic.target_critic(zs[:-2]).mode, + discount_factors[:-2], + actions[1:-1]) + metrics_a |= self.actor_optimizer.step(losses_a['loss_actor']) + metrics_c |= self.critic_optimizer.step(losses_c['loss_critic']) self.critic.update_target() - self.scaler.update() + + losses = losses_wm | losses_a | losses_c + metrics = metrics_wm | metrics_a | metrics_c losses = {l: val.detach().cpu().numpy() for l, val in losses.items()} - losses_ac = {l: val.detach().cpu().numpy() for l, val in losses_ac.items()} metrics = {l: val.detach().cpu().numpy() for l, val in metrics.items()} losses['total'] = sum(losses.values()) - return losses | losses_ac | metrics + return losses | metrics def save_ckpt(self, epoch_num: int, losses: dict[str, float]): torch.save( diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index 20c1718..86044aa 100644 --- 
a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -1,35 +1,75 @@ _target_: rl_sandbox.agents.DreamerV2 -layer_norm: true -# World model parameters -batch_cluster_size: 50 -latent_dim: 16 -latent_classes: 16 -rssm_dim: 40 -slots_num: 8 -kl_loss_scale: 32.0 -kl_loss_balancing: 0.8 -kl_loss_free_nats: 0.00 -world_model_lr: 3e-4 -world_model_predict_discount: false -# ActorCritic parameters -discount_factor: 0.999 imagination_horizon: 15 +batch_cluster_size: 50 +layer_norm: true + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model.WorldModel + _partial_: true + batch_cluster_size: ${..batch_cluster_size} + latent_dim: 16 + latent_classes: 16 + rssm_dim: 40 + slots_num: 2 + kl_loss_scale: 2.0 + kl_loss_balancing: 0.8 + kl_free_nats: 0.05 + discrete_rssm: false + decode_vit: false + vit_l2_ratio: 1.0 + encode_vit: false + predict_discount: false + layer_norm: ${..layer_norm} + +actor: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor + _partial_: true + # mixing of reinforce and maximizing value func + # for dm_control it is zero in Dreamer (Atari 1) + reinforce_fraction: null + entropy_scale: 1e-4 + layer_norm: ${..layer_norm} + +critic: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic + _partial_: true + discount_factor: 0.999 + update_interval: 100 + # [0-1], 1 means hard update + soft_update_fraction: 1 + # Lambda parameter for trainin deeper multi-step prediction + value_target_lambda: 0.95 + layer_norm: ${..layer_norm} -actor_lr: 8e-5 -# mixing of reinforce and maximizing value func -# for dm_control it is zero in Dreamer (Atari 1) -actor_reinforce_fraction: null -actor_entropy_scale: 1e-4 +wm_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 + #- _target_: rl_sandbox.utils.optimizer.DecayScheduler + # _partial_: true + # decay_rate: 0.5 + # decay_steps: 5e5 + lr: 3e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 -critic_lr: 8e-5 -# Lambda parameter for trainin deeper multi-step prediction -critic_value_target_lambda: 0.95 -critic_update_interval: 100 -# [0-1], 1 means hard update -critic_soft_update_fraction: 1 +actor_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 8e-5 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 -discrete_rssm: false -decode_vit: false -vit_l2_ratio: 1.0 -encode_vit: false +critic_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 8e-5 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 diff --git a/rl_sandbox/utils/logger.py b/rl_sandbox/utils/logger.py index 09f6f1c..a54f166 100644 --- a/rl_sandbox/utils/logger.py +++ b/rl_sandbox/utils/logger.py @@ -43,9 +43,9 @@ def log(self, losses: dict[str, t.Any], global_step: int, mode: str = 'train'): for loss_name, loss in losses.items(): if 'grad' in loss_name: if self.log_grads: - self.writer.add_histogram(f'train/{loss_name}', loss, global_step) + self.writer.add_histogram(f'{mode}/{loss_name}', loss, global_step) else: - self.writer.add_scalar(f'train/{loss_name}', loss.item(), global_step) + self.writer.add_scalar(f'{mode}/{loss_name}', loss.item(), global_step) def add_scalar(self, name: str, value: t.Any, global_step: int): self.writer.add_scalar(name, value, global_step) diff --git a/rl_sandbox/utils/optimizer.py b/rl_sandbox/utils/optimizer.py new file mode 100644 index 0000000..77a176f --- /dev/null +++ b/rl_sandbox/utils/optimizer.py @@ -0,0 +1,71 @@ +import 
typing as t +from collections.abc import Iterable +import torch +import numpy as np +from torch import nn +from torch.optim.lr_scheduler import LRScheduler +from torch.cuda.amp import GradScaler + +from torch.optim.lr_scheduler import LinearLR, LambdaLR + +class WarmupScheduler(LinearLR): + def __init__(self, optimizer, warmup_steps): + super().__init__(optimizer, start_factor=1/warmup_steps, total_iters=int(warmup_steps)) + +# class WarmupScheduler(LambdaLR): +# def __init__(self, optimizer, warmup_steps): +# super().__init__(optimizer, lambda epoch: min(1, np.interp(epoch, [1, warmup_steps], [0, 1])) ) + +class DecayScheduler(LambdaLR): + def __init__(self, optimizer, decay_steps, decay_rate): + super().__init__(optimizer, lambda epoch: decay_rate**(epoch/decay_steps)) + +class Optimizer: + def __init__(self, model, + lr=1e-4, + eps=1e-8, + weight_decay=0.01, + lr_scheduler: t.Optional[t.Type[LRScheduler] | t.Iterable[t.Type[LRScheduler]]] = None, + scaler: bool = False, + log_grad: bool = False, + clip: t.Optional[float] = None): + self.model = model + self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=eps, weight_decay=weight_decay) + self.lr_scheduler = lr_scheduler + if lr_scheduler is not None and not isinstance(lr_scheduler, Iterable): + self.lr_scheduler = lr_scheduler(optimizer=self.optimizer) + elif isinstance(lr_scheduler, Iterable): + self.lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler([lr_sched(optimizer=self.optimizer) for lr_sched in lr_scheduler]) + self.log_grad = log_grad + self.scaler = GradScaler() if scaler else None + self.clip = clip + + def step(self, loss): + metrics = {} + self.optimizer.zero_grad(set_to_none=True) + + if self.scaler: + loss = self.scaler.scale(loss) + loss.backward() + + if self.scaler: + self.scaler.unscale_(self.optimizer) + + if self.log_grad: + for tag, value in self.model.named_parameters(): + metrics[f"grad/{tag.replace('.', '/')}"] = value.detach() + + if self.clip: + nn.utils.clip_grad_norm_(self.model.parameters(), self.clip) + + if self.scaler: + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.optimizer.step() + + if self.lr_scheduler: + self.lr_scheduler.step() + metrics[f'lr/{self.model.__class__.__name__}'] = torch.Tensor(self.lr_scheduler.get_last_lr()) + + return metrics From 01db4174e787e99452b8192e3c2d34fe93f83a98 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Thu, 27 Apr 2023 17:58:40 +0000 Subject: [PATCH 066/106] Refactored to have simultaneously multiple version of world models (dino/slotted) --- README.md | 7 +- rl_sandbox/agents/dreamer/ac.py | 7 + rl_sandbox/agents/dreamer/common.py | 50 +-- rl_sandbox/agents/dreamer/rssm.py | 150 +++++---- rl_sandbox/agents/dreamer/rssm_slots.py | 265 ++++++++++++++++ rl_sandbox/agents/dreamer/vision.py | 73 +++-- rl_sandbox/agents/dreamer/world_model.py | 155 ++++------ .../agents/dreamer/world_model_slots.py | 285 ++++++++++++++++++ rl_sandbox/agents/dreamer_v2.py | 123 +------- rl_sandbox/config/agent/dreamer_v2.yaml | 22 +- .../config/agent/dreamer_v2_slotted.yaml | 75 +++++ rl_sandbox/config/config.yaml | 23 +- rl_sandbox/config/config_dino.yaml | 43 +++ rl_sandbox/config/config_slotted.yaml | 46 +++ rl_sandbox/metrics.py | 228 +++++++++++++- rl_sandbox/train.py | 35 +-- 16 files changed, 1204 insertions(+), 383 deletions(-) create mode 100644 rl_sandbox/agents/dreamer/rssm_slots.py create mode 100644 rl_sandbox/agents/dreamer/world_model_slots.py create mode 100644 rl_sandbox/config/agent/dreamer_v2_slotted.yaml create mode 
100644 rl_sandbox/config/config_dino.yaml create mode 100644 rl_sandbox/config/config_slotted.yaml diff --git a/README.md b/README.md index bfa5b71..0f2851c 100644 --- a/README.md +++ b/README.md @@ -14,5 +14,10 @@ docker run --gpus 'all' -it --rm -v `pwd`:/home/$USER/rl_sandbox -w /home/$USER/ Run training inside docker on gpu 0: ```sh -docker run --gpus 'device=0' -it --rm -v `pwd`:/home/$USER/rl_sandbox -w /home/$USER/rl_sandbox dreamer python3 rl_sandbox/train.py +docker run --gpus 'device=0' -it --rm -v `pwd`:/home/$USER/rl_sandbox -w /home/$USER/rl_sandbox dreamer python3 rl_sandbox/train.py --config-name config_dino +``` + +To run dreamer version with slot attention use: +``` +rl_sandbox/train.py --config-name config_slotted ``` diff --git a/rl_sandbox/agents/dreamer/ac.py b/rl_sandbox/agents/dreamer/ac.py index 515e7e6..d774fda 100644 --- a/rl_sandbox/agents/dreamer/ac.py +++ b/rl_sandbox/agents/dreamer/ac.py @@ -103,6 +103,13 @@ def __init__(self, latent_dim: int, actions_num: int, is_discrete: bool, def forward(self, z: torch.Tensor) -> td.Distribution: return self.actor(z) + def get_action(self, state) -> td.Distribution: + # FIXME: you should be ashamed for such fix for prev_slots + if isinstance(state, tuple): + return self.actor(state[0].combined) + else: + return self.actor(state.combined) + def calculate_loss(self, zs: torch.Tensor, vs: torch.Tensor, baseline: torch.Tensor, discount_factors: torch.Tensor, actions: torch.Tensor): losses = {} diff --git a/rl_sandbox/agents/dreamer/common.py b/rl_sandbox/agents/dreamer/common.py index 86ff51f..e3efcb2 100644 --- a/rl_sandbox/agents/dreamer/common.py +++ b/rl_sandbox/agents/dreamer/common.py @@ -1,10 +1,9 @@ import torch -import typing as t -from dataclasses import dataclass -from jaxtyping import Float, Bool from torch import nn + from rl_sandbox.utils.dists import DistLayer + class View(nn.Module): def __init__(self, shape): @@ -14,58 +13,25 @@ def __init__(self, shape): def forward(self, x): return x.view(*self.shape) + def Dist(val): return DistLayer('onehot')(val) -@dataclass -class State: - determ: Float[torch.Tensor, 'seq batch num_slots determ'] - stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] - stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None - - @property - def combined(self): - return torch.concat([self.determ, self.stoch], dim=-1).flatten(2, 3) - - @property - def combined_slots(self): - return torch.concat([self.determ, self.stoch], dim=-1) - - @property - def stoch(self): - if self.stoch_ is None: - self.stoch_ = Dist(self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:3] + (-1,)) - return self.stoch_ - - @property - def stoch_dist(self): - return Dist(self.stoch_logits) - - @classmethod - def stack(cls, states: list['State'], dim = 0): - if states[0].stoch_ is not None: - stochs = torch.cat([state.stoch for state in states], dim=dim) - else: - stochs = None - return State(torch.cat([state.determ for state in states], dim=dim), - torch.cat([state.stoch_logits for state in states], dim=dim), - stochs) - class Normalizer(nn.Module): + def __init__(self, momentum=0.99, scale=1.0, eps=1e-8): super().__init__() self.momentum = momentum self.scale = scale - self.eps= eps + self.eps = eps self.register_buffer('mag', torch.ones(1, dtype=torch.float32)) self.mag.requires_grad = False def forward(self, x): self.update(x) - return (x / (self.mag + self.eps))*self.scale + return (x / (self.mag + self.eps)) * self.scale def update(self, x): - self.mag 
= self.momentum * self.mag + (1 - self.momentum) * (x.abs().mean()).detach() - - + self.mag = self.momentum * self.mag + (1 - + self.momentum) * (x.abs().mean()).detach() diff --git a/rl_sandbox/agents/dreamer/rssm.py b/rl_sandbox/agents/dreamer/rssm.py index f7de2f9..274d72d 100644 --- a/rl_sandbox/agents/dreamer/rssm.py +++ b/rl_sandbox/agents/dreamer/rssm.py @@ -1,41 +1,81 @@ +import typing as t +from dataclasses import dataclass + import torch +from jaxtyping import Bool, Float from torch import nn from torch.nn import functional as F +from rl_sandbox.agents.dreamer import Dist, View from rl_sandbox.utils.schedulers import LinearScheduler -from rl_sandbox.agents.dreamer import State, View +@dataclass +class State: + determ: Float[torch.Tensor, 'seq batch determ'] + stoch_logits: Float[torch.Tensor, 'seq batch latent_classes latent_dim'] + stoch_: t.Optional[Bool[torch.Tensor, 'seq batch stoch_dim']] = None + + @property + def combined(self): + return torch.concat([self.determ, self.stoch], dim=-1) + + @property + def stoch(self): + if self.stoch_ is None: + self.stoch_ = Dist(self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:2] + (-1,)) + return self.stoch_ + + @property + def stoch_dist(self): + return Dist(self.stoch_logits) + + @classmethod + def stack(cls, states: list['State'], dim = 0): + if states[0].stoch_ is not None: + stochs = torch.cat([state.stoch for state in states], dim=dim) + else: + stochs = None + return State(torch.cat([state.determ for state in states], dim=dim), + torch.cat([state.stoch_logits for state in states], dim=dim), + stochs) + +# TODO: move to common class GRUCell(nn.Module): - def __init__(self, input_size, hidden_size, norm=False, update_bias=-1, **kwargs): - super().__init__() - self._size = hidden_size - self._act = torch.tanh - self._norm = norm - self._update_bias = update_bias - self._layer = nn.Linear(input_size + hidden_size, 3 * hidden_size, bias=norm is not None, **kwargs) - if norm: - self._norm = nn.LayerNorm(3 * hidden_size) - - @property - def state_size(self): - return self._size - - def forward(self, x, h): - state = h - parts = self._layer(torch.concat([x, state], -1)) - if self._norm: - dtype = parts.dtype - parts = self._norm(parts.float()) - parts = parts.to(dtype=dtype) - reset, cand, update = parts.chunk(3, dim=-1) - reset = torch.sigmoid(reset) - cand = self._act(reset * cand) - update = torch.sigmoid(update + self._update_bias) - output = update * cand + (1 - update) * state - return output, output + + def __init__(self, input_size, hidden_size, norm=False, update_bias=-1, **kwargs): + super().__init__() + self._size = hidden_size + self._act = torch.tanh + self._norm = norm + self._update_bias = update_bias + self._layer = nn.Linear(input_size + hidden_size, + 3 * hidden_size, + bias=norm is not None, + **kwargs) + if norm: + self._norm = nn.LayerNorm(3 * hidden_size) + + @property + def state_size(self): + return self._size + + def forward(self, x, h): + state = h + parts = self._layer(torch.concat([x, state], -1)) + if self._norm: + dtype = parts.dtype + parts = self._norm(parts.float()) + parts = parts.to(dtype=dtype) + reset, cand, update = parts.chunk(3, dim=-1) + reset = torch.sigmoid(reset) + cand = self._act(reset * cand) + update = torch.sigmoid(update + self._update_bias) + output = update * cand + (1 - update) * state + return output, output class Quantize(nn.Module): + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): super().__init__() @@ -45,8 +85,8 @@ def __init__(self, dim, n_embed, decay=0.99, 
eps=1e-5): self.eps = eps embed = torch.randn(dim, n_embed) - self.inp_in = nn.Linear(1024, self.n_embed*self.dim) - self.inp_out = nn.Linear(self.n_embed*self.dim, 1024) + self.inp_in = nn.Linear(1024, self.n_embed * self.dim) + self.inp_out = nn.Linear(self.n_embed * self.dim, 1024) self.register_buffer("embed", embed) self.register_buffer("cluster_size", torch.zeros(n_embed)) self.register_buffer("embed_avg", embed.clone()) @@ -56,11 +96,8 @@ def forward(self, inp): input = inp.reshape(-1, 1, self.n_embed, self.dim) inp = input flatten = input.reshape(-1, self.dim) - dist = ( - flatten.pow(2).sum(1, keepdim=True) - - 2 * flatten @ self.embed - + self.embed.pow(2).sum(0, keepdim=True) - ) + dist = (flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ self.embed + + self.embed.pow(2).sum(0, keepdim=True)) _, embed_ind = (-dist).max(1) embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) embed_ind = embed_ind.view(*input.shape[:-1]) @@ -70,20 +107,19 @@ def forward(self, inp): embed_onehot_sum = embed_onehot.sum(0) embed_sum = flatten.transpose(0, 1) @ embed_onehot - self.cluster_size.data.mul_(self.decay).add_( - embed_onehot_sum, alpha=1 - self.decay - ) + self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, + alpha=1 - self.decay) self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) n = self.cluster_size.sum() - cluster_size = ( - (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n - ) + cluster_size = ((self.cluster_size + self.eps) / + (n + self.n_embed * self.eps) * n) embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) self.embed.data.copy_(embed_normalized) # quantize_out = self.inp_out(quantize.reshape(-1, self.n_embed*self.dim)) quantize_out = quantize - diff = 0.25*(quantize_out.detach() - inp).pow(2).mean() + (quantize_out - inp.detach()).pow(2).mean() + diff = 0.25 * (quantize_out.detach() - inp).pow(2).mean() + ( + quantize_out - inp.detach()).pow(2).mean() quantize = inp + (quantize_out - inp).detach() return quantize, diff, embed_ind @@ -91,16 +127,13 @@ def forward(self, inp): def embed_code(self, embed_id): return F.embedding(embed_id, self.embed.transpose(0, 1)) - class RSSM(nn.Module): """ Recurrent State Space Model - h_t <- deterministic state which is updated inside GRU s^_t <- stohastic discrete prior state (used for KL divergence: better predict future and encode smarter) s_t <- stohastic discrete posterior state (latent representation of current state) - h_1 ---> h_2 ---> h_3 ---> \\ x_1 \\ x_2 \\ x_3 | \\ | ^ | \\ | ^ | \\ | ^ @@ -112,10 +145,9 @@ class RSSM(nn.Module): v v | v v | v v | | | | s^_1 s_1 ---| s^_2 s_2 ---| s^_3 s_3 ---| - """ - def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discrete_rssm, norm_layer: nn.LayerNorm | nn.Identity, embed_size = 2*2*384): + def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discrete_rssm, norm_layer: nn.LayerNorm | nn.Identity): super().__init__() self.latent_dim = latent_dim self.latent_classes = latent_classes @@ -146,9 +178,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret # For observation we do not have ensemble # FIXME: very bad magic number - # img_sz = 4 * 384 # 384x2x2 - # img_sz = 192 - img_sz = embed_size + img_sz = 4 * 384 # 384*2x2 self.stoch_net = nn.Sequential( # nn.LayerNorm(hidden_size + img_sz, hidden_size), nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' @@ -173,26 +203,24 @@ def estimate_stochastic_latent(self, prev_determ: 
torch.Tensor): def predict_next(self, prev_state: State, action) -> State: - x = self.pre_determ_recurrent(torch.concat([prev_state.stoch, action.unsqueeze(2).repeat((1, 1, prev_state.determ.shape[2], 1))], dim=-1)) + x = self.pre_determ_recurrent(torch.concat([prev_state.stoch, action], dim=-1)) # NOTE: x and determ are actually the same value if sequence of 1 is inserted - x, determ_prior = self.determ_recurrent(x.flatten(1, 2), prev_state.determ.flatten(1, 2)) + x, determ_prior = self.determ_recurrent(x, prev_state.determ) if self.discrete_rssm: - raise NotImplementedError("discrete rssm was not adopted for slot attention") - # determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) - # determ_post = determ_post.reshape(determ_prior.shape) - # determ_post = self.determ_layer_norm(determ_post) - # alpha = self.discretizer_scheduler.val - # determ_post = alpha * determ_prior + (1-alpha) * determ_post + determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) + determ_post = determ_post.reshape(determ_prior.shape) + determ_post = self.determ_layer_norm(determ_post) + alpha = self.discretizer_scheduler.val + determ_post = alpha * determ_prior + (1-alpha) * determ_post else: determ_post, diff = determ_prior, 0 # used for KL divergence predicted_stoch_logits = self.estimate_stochastic_latent(x) - # Size is 1 x B x slots_num x ... - return State(determ_post.reshape(prev_state.determ.shape), predicted_stoch_logits.reshape(prev_state.stoch_logits.shape)), diff + return State(determ_post, predicted_stoch_logits), diff def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' - return State(prior.determ, self.stoch_net(torch.concat([prior.determ, embed], dim=-1)).flatten(1, 2).reshape(prior.stoch_logits.shape)) + return State(prior.determ, self.stoch_net(torch.concat([prior.determ, embed], dim=-1))) def forward(self, h_prev: State, embed, action) -> tuple[State, State]: diff --git a/rl_sandbox/agents/dreamer/rssm_slots.py b/rl_sandbox/agents/dreamer/rssm_slots.py new file mode 100644 index 0000000..21d68b4 --- /dev/null +++ b/rl_sandbox/agents/dreamer/rssm_slots.py @@ -0,0 +1,265 @@ +import typing as t +from dataclasses import dataclass + +import torch +from jaxtyping import Bool, Float +from torch import nn +from torch.nn import functional as F + +from rl_sandbox.agents.dreamer import Dist, View +from rl_sandbox.utils.schedulers import LinearScheduler + + +@dataclass +class State: + determ: Float[torch.Tensor, 'seq batch num_slots determ'] + stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] + stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None + + @property + def combined(self): + return torch.concat([self.determ, self.stoch], dim=-1).flatten(2, 3) + + @property + def combined_slots(self): + return torch.concat([self.determ, self.stoch], dim=-1) + + @property + def stoch(self): + if self.stoch_ is None: + self.stoch_ = Dist( + self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:3] + (-1, )) + return self.stoch_ + + @property + def stoch_dist(self): + return Dist(self.stoch_logits) + + @classmethod + def stack(cls, states: list['State'], dim=0): + if states[0].stoch_ is not None: + stochs = torch.cat([state.stoch for state in states], dim=dim) + else: + stochs = None + return State(torch.cat([state.determ for state in states], dim=dim), + torch.cat([state.stoch_logits for state in states], dim=dim), stochs) + + +class GRUCell(nn.Module): + + def __init__(self, input_size, hidden_size, 
norm=False, update_bias=-1, **kwargs): + super().__init__() + self._size = hidden_size + self._act = torch.tanh + self._norm = norm + self._update_bias = update_bias + self._layer = nn.Linear(input_size + hidden_size, + 3 * hidden_size, + bias=norm is not None, + **kwargs) + if norm: + self._norm = nn.LayerNorm(3 * hidden_size) + + @property + def state_size(self): + return self._size + + def forward(self, x, h): + state = h + parts = self._layer(torch.concat([x, state], -1)) + if self._norm: + dtype = parts.dtype + parts = self._norm(parts.float()) + parts = parts.to(dtype=dtype) + reset, cand, update = parts.chunk(3, dim=-1) + reset = torch.sigmoid(reset) + cand = self._act(reset * cand) + update = torch.sigmoid(update + self._update_bias) + output = update * cand + (1 - update) * state + return output, output + + +class Quantize(nn.Module): + + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): + super().__init__() + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + + embed = torch.randn(dim, n_embed) + self.inp_in = nn.Linear(1024, self.n_embed * self.dim) + self.inp_out = nn.Linear(self.n_embed * self.dim, 1024) + self.register_buffer("embed", embed) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", embed.clone()) + + def forward(self, inp): + # input = self.inp_in(inp).reshape(-1, 1, self.n_embed, self.dim) + input = inp.reshape(-1, 1, self.n_embed, self.dim) + inp = input + flatten = input.reshape(-1, self.dim) + dist = (flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ self.embed + + self.embed.pow(2).sum(0, keepdim=True)) + _, embed_ind = (-dist).max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = self.embed_code(embed_ind) + + if self.training: + embed_onehot_sum = embed_onehot.sum(0) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + + self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, + alpha=1 - self.decay) + self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) + n = self.cluster_size.sum() + cluster_size = ((self.cluster_size + self.eps) / + (n + self.n_embed * self.eps) * n) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + self.embed.data.copy_(embed_normalized) + + # quantize_out = self.inp_out(quantize.reshape(-1, self.n_embed*self.dim)) + quantize_out = quantize + diff = 0.25 * (quantize_out.detach() - inp).pow(2).mean() + ( + quantize_out - inp.detach()).pow(2).mean() + quantize = inp + (quantize_out - inp).detach() + + return quantize, diff, embed_ind + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.transpose(0, 1)) + + +class RSSM(nn.Module): + """ + Recurrent State Space Model + + h_t <- deterministic state which is updated inside GRU + s^_t <- stohastic discrete prior state (used for KL divergence: + better predict future and encode smarter) + s_t <- stohastic discrete posterior state (latent representation of current state) + + h_1 ---> h_2 ---> h_3 ---> + \\ x_1 \\ x_2 \\ x_3 + | \\ | ^ | \\ | ^ | \\ | ^ + v MLP CNN | v MLP CNN | v MLP CNN | + \\ | | \\ | | \\ | | + Ensemble \\ | | Ensemble \\ | | Ensemble \\ | | + \\| | \\| | \\| | + | | | | | | | | | + v v | v v | v v | + | | | + s^_1 s_1 ---| s^_2 s_2 ---| s^_3 s_3 ---| + + """ + + def __init__(self, + latent_dim, + hidden_size, + actions_num, + latent_classes, + discrete_rssm, + norm_layer: nn.LayerNorm | nn.Identity, + embed_size=2 * 2 * 384): + 
super().__init__() + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.ensemble_num = 1 + self.hidden_size = hidden_size + self.discrete_rssm = discrete_rssm + + # Calculate deterministic state from prev stochastic, prev action and prev deterministic + self.pre_determ_recurrent = nn.Sequential( + nn.Linear(latent_dim * latent_classes + actions_num, + hidden_size), # Dreamer 'img_in' + norm_layer(hidden_size), + nn.ELU(inplace=True)) + self.determ_recurrent = GRUCell(input_size=hidden_size, + hidden_size=hidden_size, + norm=True) # Dreamer gru '_cell' + + # Calculate stochastic state from prior embed + # shared between all ensemble models + self.ensemble_prior_estimator = nn.ModuleList([ + nn.Sequential( + nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' + View((1, -1, latent_dim, self.latent_classes))) + for _ in range(self.ensemble_num) + ]) + + # For observation we do not have ensemble + # FIXME: very bad magic number + # img_sz = 4 * 384 # 384x2x2 + # img_sz = 192 + img_sz = embed_size + self.stoch_net = nn.Sequential( + # nn.LayerNorm(hidden_size + img_sz, hidden_size), + nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'obs_dist' + View((1, -1, latent_dim, self.latent_classes))) + # self.determ_discretizer = MlpVAE(self.hidden_size) + self.determ_discretizer = Quantize(16, 16) + self.discretizer_scheduler = LinearScheduler(1.0, 0.0, 1_000_000) + self.determ_layer_norm = nn.LayerNorm(hidden_size) + + def estimate_stochastic_latent(self, prev_determ: torch.Tensor): + dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] + # NOTE: Maybe something smarter can be used instead of + # taking only one random between all ensembles + # NOTE: in Dreamer ensemble_num is always 1 + idx = torch.randint(0, self.ensemble_num, ()) + return dists_per_model[0] + + def predict_next(self, prev_state: State, action) -> State: + x = self.pre_determ_recurrent( + torch.concat([ + prev_state.stoch, + action.unsqueeze(2).repeat((1, 1, prev_state.determ.shape[2], 1)) + ], + dim=-1)) + # NOTE: x and determ are actually the same value if sequence of 1 is inserted + x, determ_prior = self.determ_recurrent(x.flatten(1, 2), + prev_state.determ.flatten(1, 2)) + if self.discrete_rssm: + raise NotImplementedError("discrete rssm was not adopted for slot attention") + # determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) + # determ_post = determ_post.reshape(determ_prior.shape) + # determ_post = self.determ_layer_norm(determ_post) + # alpha = self.discretizer_scheduler.val + # determ_post = alpha * determ_prior + (1-alpha) * determ_post + else: + determ_post, diff = determ_prior, 0 + + # used for KL divergence + predicted_stoch_logits = self.estimate_stochastic_latent(x) + # Size is 1 x B x slots_num x ... 
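+        # The GRU above consumed the slots flattened into the batch dimension
+        # (flatten(1, 2)), so the deterministic output and the predicted
+        # stochastic logits are reshaped back to the per-slot layout of
+        # prev_state before being wrapped into a State.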
+ return State(determ_post.reshape(prev_state.determ.shape), + predicted_stoch_logits.reshape(prev_state.stoch_logits.shape)), diff + + def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' + return State( + prior.determ, + self.stoch_net(torch.concat([prior.determ, embed], dim=-1)).flatten( + 1, 2).reshape(prior.stoch_logits.shape)) + + def forward(self, h_prev: State, embed, action) -> tuple[State, State]: + """ + 'h' <- internal state of the world + 'z' <- latent embedding of current observation + 'a' <- action taken on prev step + Returns 'h_next' <- the next next of the world + """ + prior, diff = self.predict_next(h_prev, action) + posterior = self.update_current(prior, embed) + + return prior, posterior, diff diff --git a/rl_sandbox/agents/dreamer/vision.py b/rl_sandbox/agents/dreamer/vision.py index b04be2d..dcce3bf 100644 --- a/rl_sandbox/agents/dreamer/vision.py +++ b/rl_sandbox/agents/dreamer/vision.py @@ -1,24 +1,32 @@ import torch.distributions as td from torch import nn + class Encoder(nn.Module): - def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[4, 4, 4]): + def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, + channel_step=96, + kernel_sizes=[4, 4, 4], + double_conv=False, + flatten_output=True, + ): super().__init__() layers = [] - channel_step = 96 in_channels = 3 for i, k in enumerate(kernel_sizes): out_channels = 2**i * channel_step layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2)) layers.append(norm_layer(1, out_channels)) layers.append(nn.ELU(inplace=True)) - layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')) - layers.append(norm_layer(1, out_channels)) - layers.append(nn.ELU(inplace=True)) + if double_conv: + layers.append( + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')) + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) in_channels = out_channels - # layers.append(nn.Flatten()) + if flatten_output: + layers.append(nn.Flatten()) self.net = nn.Sequential(*layers) def forward(self, X): @@ -27,24 +35,38 @@ def forward(self, X): class Decoder(nn.Module): - def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 6, 6]): + def __init__(self, + input_size, + norm_layer: nn.GroupNorm | nn.Identity, + kernel_sizes=[5, 5, 6, 6], + channel_step = 48, + output_channels=3, + return_dist=True): super().__init__() layers = [] - self.channel_step = 48 + self.channel_step = channel_step # 2**(len(kernel_sizes)-1)*channel_step self.convin = nn.Linear(input_size, 32 * self.channel_step) + self.return_dist = return_dist in_channels = 32 * self.channel_step #2**(len(kernel_sizes) - 1) * self.channel_step for i, k in enumerate(kernel_sizes): out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step if i == len(kernel_sizes) - 1: out_channels = 3 - layers.append(nn.ConvTranspose2d(in_channels, 4, kernel_size=k, stride=2)) + layers.append(nn.ConvTranspose2d(in_channels, + output_channels, + kernel_size=k, + stride=2, + output_padding=0)) else: layers.append(norm_layer(1, in_channels)) layers.append( - nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, - stride=2)) + nn.ConvTranspose2d(in_channels, + out_channels, + kernel_size=k, + stride=2, + output_padding=0)) layers.append(nn.ELU(inplace=True)) in_channels = out_channels self.net = nn.Sequential(*layers) @@ -52,14 +74,20 @@ def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_si def 
forward(self, X): x = self.convin(X) x = x.view(-1, 32 * self.channel_step, 1, 1) - return self.net(x) - # return td.Independent(td.Normal(self.net(x), 1.0), 3) + if self.return_dist: + return td.Independent(td.Normal(self.net(x), 1.0), 3) + else: + return self.net(x) + class ViTDecoder(nn.Module): # def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3, 5, 3]): # def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 5, 3]): - def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3, 3]): + def __init__(self, + input_size, + norm_layer: nn.GroupNorm | nn.Identity, + kernel_sizes=[5, 5, 5, 3, 3]): super().__init__() layers = [] self.channel_step = 12 @@ -71,11 +99,21 @@ def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_si out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step if i == len(kernel_sizes) - 1: out_channels = 3 - layers.append(nn.ConvTranspose2d(in_channels, 384, kernel_size=k, stride=1, padding=1)) + layers.append( + nn.ConvTranspose2d(in_channels, + 384, + kernel_size=k, + stride=1, + padding=1)) else: layers.append(norm_layer(1, in_channels)) layers.append( - nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=2, padding=2, output_padding=1)) + nn.ConvTranspose2d(in_channels, + out_channels, + kernel_size=k, + stride=2, + padding=2, + output_padding=1)) layers.append(nn.ELU(inplace=True)) in_channels = out_channels self.net = nn.Sequential(*layers) @@ -84,6 +122,3 @@ def forward(self, X): x = self.convin(X) x = x.view(-1, 32 * self.channel_step, 1, 1) return td.Independent(td.Normal(self.net(x), 1.0), 3) - - - diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index 0f59e25..3ad2a4b 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -1,26 +1,24 @@ import typing as t + import torch import torch.distributions as td +import torchvision as tv from torch import nn from torch.nn import functional as F -import torchvision as tv -from rl_sandbox.vision.dino import ViTFeat - -from rl_sandbox.utils.fc_nn import fc_nn_generator +from rl_sandbox.agents.dreamer import Dist, Normalizer +from rl_sandbox.agents.dreamer.rssm import RSSM, State +from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, ViTDecoder from rl_sandbox.utils.dists import DistLayer -from rl_sandbox.vision.slot_attention import SlotAttention, PositionalEmbedding - -from rl_sandbox.agents.dreamer import Dist, State, Normalizer -from rl_sandbox.agents.dreamer.rssm import RSSM -from rl_sandbox.agents.dreamer.vision import Encoder, Decoder, ViTDecoder +from rl_sandbox.utils.fc_nn import fc_nn_generator +from rl_sandbox.vision.dino import ViTFeat +from rl_sandbox.vision.slot_attention import PositionalEmbedding, SlotAttention class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, - predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float, - slots_num: int): + predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float): super().__init__() self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) self.kl_beta = kl_loss_scale @@ -28,8 +26,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.rssm_dim = rssm_dim 
self.latent_dim = latent_dim self.latent_classes = latent_classes - self.slots_num = slots_num - self.state_size = slots_num * (rssm_dim + latent_dim * latent_classes) + self.state_size = (rssm_dim + latent_dim * latent_classes) self.cluster_size = batch_cluster_size self.actions_num = actions_num @@ -40,19 +37,16 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.decode_vit = decode_vit self.vit_l2_ratio = vit_l2_ratio - self.n_dim = 384 - self.recurrent_model = RSSM(latent_dim, rssm_dim, actions_num, latent_classes, discrete_rssm, - norm_layer=nn.Identity if layer_norm else nn.LayerNorm, - embed_size=self.n_dim) + norm_layer=nn.Identity if layer_norm else nn.LayerNorm) if encode_vit or decode_vit: # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) - # self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) - self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=16) + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) + # self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=16) self.vit_feat_dim = self.dino_vit.feat_dim self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches self.dino_vit.requires_grad_(False) @@ -69,40 +63,37 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, # layer_norm=layer_norm) ) else: - self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm) - - self.slot_attention = SlotAttention(slots_num, self.n_dim, 5) - self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) - # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) - - self.slot_mlp = nn.Sequential( - nn.Linear(self.n_dim, self.n_dim), - nn.ReLU(inplace=True), - nn.Linear(self.n_dim, self.n_dim) - ) + self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + kernel_sizes=[4, 4, 4, 4], + channel_step=48, + double_conv=False) if decode_vit: - self.dino_predictor = ViTDecoder(rssm_dim + latent_dim * latent_classes, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm) - # self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, + self.dino_predictor = Decoder(self.state_size, + norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + channel_step=192, + kernel_sizes=[3, 4], + output_channels=self.vit_feat_dim, + return_dist=True) + # self.dino_predictor = fc_nn_generator(self.state_size, # 64*self.dino_vit.feat_dim, # hidden_size=2048, # num_layers=5, # intermediate_activation=nn.ELU, # layer_norm=layer_norm, # final_activation=DistLayer('mse')) - self.image_predictor = Decoder(rssm_dim + latent_dim * latent_classes, + self.image_predictor = Decoder(self.state_size, norm_layer=nn.Identity if layer_norm else nn.GroupNorm) - self.reward_predictor = fc_nn_generator(slots_num*(rssm_dim + latent_dim * latent_classes), + self.reward_predictor = fc_nn_generator(self.state_size, 1, hidden_size=400, num_layers=5, intermediate_activation=nn.ELU, layer_norm=layer_norm, final_activation=DistLayer('mse')) - self.discount_predictor = fc_nn_generator(slots_num*(rssm_dim + latent_dim * latent_classes), + self.discount_predictor = fc_nn_generator(self.state_size, 1, 
hidden_size=400, num_layers=5, @@ -113,9 +104,9 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): device = next(self.parameters()).device - return State(torch.zeros(seq_size, batch_size, self.slots_num, self.rssm_dim, device=device), - torch.zeros(seq_size, batch_size, self.slots_num, self.latent_classes, self.latent_dim, device=device), - torch.zeros(seq_size, batch_size, self.slots_num, self.latent_classes * self.latent_dim, device=device)) + return State(torch.zeros(seq_size, batch_size, self.rssm_dim, device=device), + torch.zeros(seq_size, batch_size, self.latent_classes, self.latent_dim, device=device), + torch.zeros(seq_size, batch_size, self.latent_classes * self.latent_dim, device=device)) def predict_next(self, prev_state: State, action): prior, _ = self.recurrent_model.predict_next(prev_state, action) @@ -127,29 +118,20 @@ def predict_next(self, prev_state: State, action): discount_factors = torch.ones_like(reward) return prior, reward, discount_factors - def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State], prev_slots: t.Optional[torch.Tensor]) -> t.Tuple[State, torch.Tensor]: + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State]) -> State: if state is None: state = self.get_initial_state() embed = self.encoder(obs.unsqueeze(0)) - embed_with_pos_enc = self.positional_augmenter_inp(embed) - - pre_slot_features_t = self.slot_mlp(embed_with_pos_enc.permute(0, 2, 3, 1).reshape(1, -1, self.n_dim)) - - slots_t = self.slot_attention(pre_slot_features_t, prev_slots) - - _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), action) - return posterior, slots_t + _, posterior, _ = self.recurrent_model.forward(state, embed.unsqueeze(0), + action) + return posterior def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, discount: torch.Tensor, first: torch.Tensor): b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) - embed_with_pos_enc = self.positional_augmenter_inp(embed) - # embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) - - pre_slot_features = self.slot_mlp(embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) - pre_slot_features_c = pre_slot_features.reshape(b // self.cluster_size, self.cluster_size, -1, self.n_dim) + embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) a_c = a.reshape(-1, self.cluster_size, self.actions_num) r_c = r.reshape(-1, self.cluster_size, 1) @@ -159,12 +141,16 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, losses = {} metrics = {} - def KL(dist1, dist2): + def KL(dist1, dist2, free_nat = True): KL_ = torch.distributions.kl_divergence - kl_lhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), td.OneHotCategoricalStraightThrough(logits=dist1)).mean() - kl_rhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2), td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() - kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) - kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) + one = self.kl_free_nats * torch.ones(1, device=next(self.parameters()).device) + # TODO: kl_free_avg is used always + if free_nat: + kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)).mean(), one) + kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())).mean(), one) + else: + kl_lhs = KL_(Dist(dist2.detach()), Dist(dist1)).mean() + kl_rhs = KL_(Dist(dist2), 
Dist(dist1.detach())).mean() return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) priors = [] @@ -173,25 +159,21 @@ def KL(dist1, dist2): if self.decode_vit: inp = obs if not self.encode_vit: - ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)), - tv.transforms.Resize(224, antialias=True)]) - # ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - # (0.229, 0.224, 0.225)) + # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + # (0.229, 0.224, 0.225)), + # tv.transforms.Resize(224, antialias=True)]) + ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) inp = ToTensor(obs + 0.5) d_features = self.dino_vit(inp) prev_state = self.get_initial_state(b // self.cluster_size) - prev_slots = None for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 - pre_slot_feature_t, a_t, first_t = pre_slot_features_c[:, t], a_c[:, t].unsqueeze(0), first_c[:, t].unsqueeze(0) + embed_t, a_t, first_t = embed_c[:, t].unsqueeze(0), a_c[:, t].unsqueeze(0), first_c[:, t].unsqueeze(0) a_t = a_t * (1 - first_t) - slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) - # prev_slots = None - - prior, posterior, diff = self.recurrent_model.forward(prev_state, slots_t.unsqueeze(0), a_t) + prior, posterior, diff = self.recurrent_model.forward(prev_state, embed_t, a_t) prev_state = posterior priors.append(prior) @@ -208,23 +190,20 @@ def KL(dist1, dist2): losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) if not self.decode_vit: - decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) - decoded_imgs = decoded_imgs * img_mask - x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) - + x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() else: - raise NotImplementedError("") - # if self.vit_l2_ratio != 1.0: - # x_r = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) - # img_rec = -x_r.log_prob(obs).float().mean() - # else: - # img_rec = 0 - # x_r_detached = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1).detach()) - # losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() - # d_pred = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) - # losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 14, 14)).float().mean()/4 + + if self.vit_l2_ratio != 1.0: + x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + img_rec = -x_r.log_prob(obs).float().mean() + else: + img_rec = 0 + x_r_detached = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1).detach()) + losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() + d_pred = self.dino_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 8, 8)).float().mean() + + (1-self.vit_l2_ratio) * img_rec) + # losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.flatten(1, 2)).float().mean() + # (1-self.vit_l2_ratio) * img_rec) prior_logits = prior.stoch_logits @@ -239,11 +218,7 @@ def KL(dist1, dist2): metrics['prior_entropy'] = 
Dist(prior_logits).entropy().mean() metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() - losses['loss_wm'] = (losses['loss_reconstruction'] + - losses['loss_reward_pred'] + - losses['loss_kl_reg'] + - losses['loss_discount_pred']) + losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + + losses['loss_kl_reg'] + losses['loss_discount_pred']) return losses, posterior, metrics - - diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py new file mode 100644 index 0000000..c91ba82 --- /dev/null +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -0,0 +1,285 @@ +import typing as t + +import torch +import torch.distributions as td +import torchvision as tv +from torch import nn +from torch.nn import functional as F + +from rl_sandbox.agents.dreamer import Dist, Normalizer +from rl_sandbox.agents.dreamer.rssm_slots import RSSM, State +from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, ViTDecoder +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.utils.fc_nn import fc_nn_generator +from rl_sandbox.vision.dino import ViTFeat +from rl_sandbox.vision.slot_attention import PositionalEmbedding, SlotAttention + + +class WorldModel(nn.Module): + + def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, + actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, + discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, + decode_vit: bool, vit_l2_ratio: float, slots_num: int): + super().__init__() + self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.kl_beta = kl_loss_scale + + self.rssm_dim = rssm_dim + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.slots_num = slots_num + self.state_size = slots_num * (rssm_dim + latent_dim * latent_classes) + + self.cluster_size = batch_cluster_size + self.actions_num = actions_num + # kl loss balancing (prior/posterior) + self.alpha = kl_loss_balancing + self.predict_discount = predict_discount + self.encode_vit = encode_vit + self.decode_vit = decode_vit + self.vit_l2_ratio = vit_l2_ratio + + self.n_dim = 384 + + self.recurrent_model = RSSM( + latent_dim, + rssm_dim, + actions_num, + latent_classes, + discrete_rssm, + norm_layer=nn.Identity if layer_norm else nn.LayerNorm, + embed_size=self.n_dim) + if encode_vit or decode_vit: + # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) + # self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) + self.dino_vit = ViTFeat( + "/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + feat_dim=384, + vit_arch='small', + patch_size=16) + self.vit_feat_dim = self.dino_vit.feat_dim + self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches + self.dino_vit.requires_grad_(False) + + if encode_vit: + self.encoder = nn.Sequential( + self.dino_vit, + nn.Flatten(), + # fc_nn_generator(64*self.dino_vit.feat_dim, + # 64*384, + # hidden_size=400, + # num_layers=5, + # intermediate_activation=nn.ELU, + # layer_norm=layer_norm) + ) + else: + self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + kernel_sizes=[4, 4, 4], + channel_step=96, + double_conv=True, + flatten_output=False) + + self.slot_attention = SlotAttention(slots_num, self.n_dim, 5) + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) + # 
self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) + + self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), + nn.ReLU(inplace=True), + nn.Linear(self.n_dim, self.n_dim)) + + if decode_vit: + self.dino_predictor = ViTDecoder( + self.state_size, + norm_layer=nn.Identity if layer_norm else nn.GroupNorm) + # self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, + # 64*self.dino_vit.feat_dim, + # hidden_size=2048, + # num_layers=5, + # intermediate_activation=nn.ELU, + # layer_norm=layer_norm, + # final_activation=DistLayer('mse')) + self.image_predictor = Decoder( + rssm_dim + latent_dim * latent_classes, + norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + output_channels=4, + return_dist=False) + + self.reward_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.discount_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('binary')) + self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): + device = next(self.parameters()).device + # Tuple of State-Space state and prev slots + return State( + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.rssm_dim, + device=device), + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.latent_classes, + self.latent_dim, + device=device), + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.latent_classes * self.latent_dim, + device=device)), None + + def predict_next(self, prev_state: State, action): + prior, _ = self.recurrent_model.predict_next(prev_state, action) + + reward = self.reward_predictor(prior.combined).mode + if self.predict_discount: + discount_factors = self.discount_predictor(prior.combined).sample() + else: + discount_factors = torch.ones_like(reward) + return prior, reward, discount_factors + + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, torch.Tensor]]) -> t.Tuple[State, torch.Tensor]: + if state is None or state[0] is None: + state, prev_slots = self.get_initial_state() + else: + state, prev_slots = state + embed = self.encoder(obs.unsqueeze(0)) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + + pre_slot_features_t = self.slot_mlp( + embed_with_pos_enc.permute(0, 2, 3, 1).reshape(1, -1, self.n_dim)) + + slots_t = self.slot_attention(pre_slot_features_t, prev_slots) + + _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), + action) + return posterior, slots_t + + def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + discount: torch.Tensor, first: torch.Tensor): + b, _, h, w = obs.shape # s <- BxHxWx3 + + embed = self.encoder(obs) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + # embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) + + pre_slot_features = self.slot_mlp( + embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) + pre_slot_features_c = pre_slot_features.reshape(b // self.cluster_size, + self.cluster_size, -1, self.n_dim) + + a_c = a.reshape(-1, self.cluster_size, self.actions_num) + r_c = r.reshape(-1, self.cluster_size, 1) + d_c = discount.reshape(-1, self.cluster_size, 1) + first_c = first.reshape(-1, self.cluster_size, 1) + + losses = {} + metrics = {} 
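+        # Balanced KL regularizer (DreamerV2-style), defined below: KL(sg(posterior) || prior)
+        # pulls the prior towards the posterior while KL(posterior || sg(prior)) regularizes the
+        # posterior; each term is clamped from below by kl_free_nats, mixed with weight alpha
+        # (kl_loss_balancing) and scaled by kl_beta (kl_loss_scale).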
+ + def KL(dist1, dist2): + KL_ = torch.distributions.kl_divergence + kl_lhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), + td.OneHotCategoricalStraightThrough(logits=dist1)).mean() + kl_rhs = KL_( + td.OneHotCategoricalStraightThrough(logits=dist2), + td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() + kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) + kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) + return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + + priors = [] + posteriors = [] + + if self.decode_vit: + inp = obs + if not self.encode_vit: + ToTensor = tv.transforms.Compose([ + tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + tv.transforms.Resize(224, antialias=True) + ]) + # ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + # (0.229, 0.224, 0.225)) + inp = ToTensor(obs + 0.5) + d_features = self.dino_vit(inp) + + prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) + for t in range(self.cluster_size): + # s_t <- 1xB^xHxWx3 + pre_slot_feature_t, a_t, first_t = pre_slot_features_c[:, + t], a_c[:, t].unsqueeze( + 0 + ), first_c[:, + t].unsqueeze( + 0) + a_t = a_t * (1 - first_t) + + slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) + # prev_slots = None + + prior, posterior, diff = self.recurrent_model.forward( + prev_state, slots_t.unsqueeze(0), a_t) + prev_state = posterior + + priors.append(prior) + posteriors.append(posterior) + + # losses['loss_determ_recons'] += diff + + posterior = State.stack(posteriors) + prior = State.stack(priors) + + r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) + f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) + + losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) + + if not self.decode_vit: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + + losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() + else: + raise NotImplementedError("") + # if self.vit_l2_ratio != 1.0: + # x_r = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) + # img_rec = -x_r.log_prob(obs).float().mean() + # else: + # img_rec = 0 + # x_r_detached = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1).detach()) + # losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() + # d_pred = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) + # losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 14, 14)).float().mean()/4 + + # (1-self.vit_l2_ratio) * img_rec) + + prior_logits = prior.stoch_logits + posterior_logits = posterior.stoch_logits + losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() + losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() + losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) + + metrics['reward_mean'] = r.mean() + metrics['reward_std'] = r.std() + metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() + metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() + metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() + + losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + + 
losses['loss_kl_reg'] + losses['loss_discount_pred']) + + return losses, posterior, metrics diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 943db4a..c099281 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -1,7 +1,6 @@ import typing as t from pathlib import Path -import matplotlib.pyplot as plt import numpy as np import torch from torch import nn @@ -15,8 +14,7 @@ TerminationFlags, IsFirstFlags) from rl_sandbox.utils.optimizer import Optimizer -from rl_sandbox.agents.dreamer import State -from rl_sandbox.agents.dreamer.world_model import WorldModel +from rl_sandbox.agents.dreamer.world_model import WorldModel, State from rl_sandbox.agents.dreamer.ac import ImaginativeCritic, ImaginativeActor @@ -86,15 +84,12 @@ def imagine_trajectory( ts.append(discount) actions.append(a) - return (State.stack(states), torch.cat(actions), torch.cat(rewards), torch.cat(ts)) + return (states[0].stack(states), torch.cat(actions), torch.cat(rewards), torch.cat(ts)) def reset(self): self._state = self.world_model.get_initial_state() - self._prev_slots = None self._last_action = torch.zeros((1, 1, self.actions_num), device=self.device) - self._latent_probs = torch.zeros((self.world_model.latent_classes, self.world_model.latent_dim), device=self.device) self._action_probs = torch.zeros((self.actions_num), device=self.device) - self._stored_steps = 0 def preprocess_obs(self, obs: torch.Tensor): # FIXME: move to dataloader in replay buffer @@ -111,127 +106,19 @@ def preprocess_obs(self, obs: torch.Tensor): def get_action(self, obs: Observation) -> Action: # NOTE: pytorch fails without .copy() only when get_action is called - # FIXME: return back action selection obs = torch.from_numpy(obs.copy()).to(self.device) obs = self.preprocess_obs(obs) - self._state, self._prev_slots = self.world_model.get_latent(obs, self._last_action, self._state, self._prev_slots) + self._state = self.world_model.get_latent(obs, self._last_action, self._state) - actor_dist = self.actor(self._state.combined) + actor_dist = self.actor.get_action(self._state) self._last_action = actor_dist.sample() - if self.is_discrete: - self._action_probs += actor_dist.probs.squeeze().mean(dim=0) - self._latent_probs += self._state.stoch_dist.probs.squeeze().mean(dim=0) - self._stored_steps += 1 - if self.is_discrete: return self._last_action.squeeze().detach().cpu().numpy().argmax() else: return self._last_action.squeeze().detach().cpu().numpy() - def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): - obs = torch.from_numpy(obs.copy()).to(self.device) - obs = self.preprocess_obs(obs) - actions = self.from_np(actions) - if self.is_discrete: - actions = F.one_hot(actions.to(torch.int64), num_classes=self.actions_num).squeeze() - video = [] - slots_video = [] - rews = [] - - state = None - prev_slots = None - means = np.array([0.485, 0.456, 0.406]) - stds = np.array([0.229, 0.224, 0.225]) - UnNormalize = tv.transforms.Normalize(list(-means/stds), - list(1/stds)) - for idx, (o, a) in enumerate(list(zip(obs, actions))): - if idx > update_num: - break - state, prev_slots = self.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), state, prev_slots) - # video_r = self.world_model.image_predictor(state.combined_slots).mode.cpu().detach().numpy() - - decoded_imgs, masks = self.world_model.image_predictor(state.combined_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) - # TODO: try the scaling of softmax as in attention - img_mask = 
F.softmax(masks, dim=1) - decoded_imgs = decoded_imgs * img_mask - video_r = torch.sum(decoded_imgs, dim=1).cpu().detach().numpy() - - rews.append(self.world_model.reward_predictor(state.combined).mode.item()) - if self.world_model.encode_vit: - video_r = UnNormalize(torch.from_numpy(video_r)).numpy() - else: - video_r = (video_r + 0.5) - video.append(video_r) - slots_video.append(decoded_imgs.cpu().detach().numpy() + 0.5) - - rews = torch.Tensor(rews).to(obs.device) - - if update_num < len(obs): - states, _, rews_2, _ = self.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.imagination_horizon - 1 - update_num) - rews = torch.cat([rews, rews_2[1:].squeeze()]) - - # video_r = self.world_model.image_predictor(states.combined_slots[1:]).mode.cpu().detach().numpy() - decoded_imgs, masks = self.world_model.image_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) - decoded_imgs = decoded_imgs * img_mask - video_r = torch.sum(decoded_imgs, dim=1).cpu().detach().numpy() - - if self.world_model.encode_vit: - video_r = UnNormalize(torch.from_numpy(video_r)).numpy() - else: - video_r = (video_r + 0.5) - video.append(video_r) - slots_video.append(decoded_imgs.cpu().detach().numpy() + 0.5) - - return np.concatenate(video), rews, np.concatenate(slots_video) - - def viz_log(self, rollout, logger, epoch_num): - init_indeces = np.random.choice(len(rollout.states) - self.imagination_horizon, 5) - - videos = np.concatenate([ - rollout.next_states[init_idx:init_idx + self.imagination_horizon].transpose( - 0, 3, 1, 2) for init_idx in init_indeces - ], axis=3).astype(np.float32) / 255.0 - - real_rewards = [rollout.rewards[idx:idx+ self.imagination_horizon] for idx in init_indeces] - - videos_r, imagined_rewards, slots_video = zip(*[self._generate_video(obs_0.copy(), a_0, update_num=self.imagination_horizon//3) for obs_0, a_0 in zip( - [rollout.next_states[idx:idx+ self.imagination_horizon] for idx in init_indeces], - [rollout.actions[idx:idx+ self.imagination_horizon] for idx in init_indeces]) - ]) - videos_r = np.concatenate(videos_r, axis=3) - - slots_video = np.concatenate(list(slots_video)[:3], axis=3) - slots_video = slots_video.transpose((0, 2, 3, 1, 4)) - slots_video = np.expand_dims(slots_video.reshape(*slots_video.shape[:-2], -1), 0) - - videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r + 1)/2], axis=2), 0) - videos_comparison = (videos_comparison * 255.0).astype(np.uint8) - latent_hist = (self._latent_probs / self._stored_steps).detach().cpu().numpy() - latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) - - # if discrete action space - if self.is_discrete: - action_hist = (self._action_probs / self._stored_steps).detach().cpu().numpy() - fig = plt.Figure() - ax = fig.add_axes([0, 0, 1, 1]) - ax.bar(np.arange(self.actions_num), action_hist) - logger.add_figure('val/action_probs', fig, epoch_num) - else: - # log mean +- std - pass - logger.add_image('val/latent_probs', latent_hist, epoch_num, dataformats='HW') - logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=1), epoch_num, dataformats='HW') - logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) - logger.add_video('val/dreamed_slots', slots_video, epoch_num) - - rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() - 
logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) - - logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) - def from_np(self, arr: np.ndarray): arr = torch.from_numpy(arr) if isinstance(arr, np.ndarray) else arr return arr.to(self.device, non_blocking=True) @@ -260,7 +147,7 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation with torch.cuda.amp.autocast(enabled=False): losses_ac = {} - initial_states = State(discovered_states.determ.flatten(0, 1).unsqueeze(0).detach(), + initial_states = discovered_states.__class__(discovered_states.determ.flatten(0, 1).unsqueeze(0).detach(), discovered_states.stoch_logits.flatten(0, 1).unsqueeze(0).detach(), discovered_states.stoch_.flatten(0, 1).unsqueeze(0).detach()) diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index 86044aa..fdc0129 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -2,21 +2,20 @@ _target_: rl_sandbox.agents.DreamerV2 imagination_horizon: 15 batch_cluster_size: 50 -layer_norm: true +layer_norm: false world_model: _target_: rl_sandbox.agents.dreamer.world_model.WorldModel _partial_: true batch_cluster_size: ${..batch_cluster_size} - latent_dim: 16 - latent_classes: 16 - rssm_dim: 40 - slots_num: 2 - kl_loss_scale: 2.0 + latent_dim: 32 + latent_classes: 32 + rssm_dim: 200 + kl_loss_scale: 1.0 kl_loss_balancing: 0.8 kl_free_nats: 0.05 discrete_rssm: false - decode_vit: false + decode_vit: true vit_l2_ratio: 1.0 encode_vit: false predict_discount: false @@ -45,14 +44,7 @@ critic: wm_optim: _target_: rl_sandbox.utils.optimizer.Optimizer _partial_: true - lr_scheduler: - - _target_: rl_sandbox.utils.optimizer.WarmupScheduler - _partial_: true - warmup_steps: 1e3 - #- _target_: rl_sandbox.utils.optimizer.DecayScheduler - # _partial_: true - # decay_rate: 0.5 - # decay_steps: 5e5 + lr_scheduler: null lr: 3e-4 eps: 1e-5 weight_decay: 1e-6 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted.yaml new file mode 100644 index 0000000..f395d45 --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_slotted.yaml @@ -0,0 +1,75 @@ +_target_: rl_sandbox.agents.DreamerV2 + +imagination_horizon: 15 +batch_cluster_size: 50 +layer_norm: true + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model_slots.WorldModel + _partial_: true + batch_cluster_size: ${..batch_cluster_size} + latent_dim: 16 + latent_classes: 16 + rssm_dim: 40 + slots_num: 2 + kl_loss_scale: 2.0 + kl_loss_balancing: 0.8 + kl_free_nats: 0.05 + discrete_rssm: false + decode_vit: false + vit_l2_ratio: 1.0 + encode_vit: false + predict_discount: false + layer_norm: ${..layer_norm} + +actor: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor + _partial_: true + # mixing of reinforce and maximizing value func + # for dm_control it is zero in Dreamer (Atari 1) + reinforce_fraction: null + entropy_scale: 1e-4 + layer_norm: ${..layer_norm} + +critic: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic + _partial_: true + discount_factor: 0.999 + update_interval: 100 + # [0-1], 1 means hard update + soft_update_fraction: 1 + # Lambda parameter for trainin deeper multi-step prediction + value_target_lambda: 0.95 + layer_norm: ${..layer_norm} + +wm_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 + #- _target_: 
rl_sandbox.utils.optimizer.DecayScheduler + # _partial_: true + # decay_rate: 0.5 + # decay_steps: 5e5 + lr: 3e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +actor_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 8e-5 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +critic_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 8e-5 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 0737b54..eee688a 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -9,26 +9,25 @@ seed: 42 device_type: cuda logger: - type: tensorboard - message: 32 KL, Cartpole 8 slots, double encoder, reduced warmup, 384 n_dim, correct prev_slots, 0.00 nats, 40 rssm dims, 16x16 stoch - #message: test_last + type: null + message: Dreamer with DINO features log_grads: false training: checkpoint_path: null steps: 1e6 - #prefill: 0 - #pretrain: 0 - batch_size: 16 - val_logs_every: 4e3 + val_logs_every: 1e3 validation: rollout_num: 5 visualize: true - -#agents: -# batch_cluster_size: 10 + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator + _partial_: true debug: profiler: false @@ -37,8 +36,8 @@ hydra: #mode: MULTIRUN mode: RUN launcher: - #n_jobs: 8 + #n_jobs: 3 n_jobs: 1 sweeper: params: - #agent.kl_loss_scale: 1.0,4.0,16.0,32.0,64.0,96.0,128.0 + #agent.kl_loss_scale: 64.0,128.0,256.0 diff --git a/rl_sandbox/config/config_dino.yaml b/rl_sandbox/config/config_dino.yaml new file mode 100644 index 0000000..eee688a --- /dev/null +++ b/rl_sandbox/config/config_dino.yaml @@ -0,0 +1,43 @@ +defaults: + - agent: dreamer_v2 + - env: dm_cartpole + - training: dm + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +logger: + type: null + message: Dreamer with DINO features + log_grads: false + +training: + checkpoint_path: null + steps: 1e6 + val_logs_every: 1e3 + + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + #n_jobs: 3 + n_jobs: 1 + sweeper: + params: + #agent.kl_loss_scale: 64.0,128.0,256.0 diff --git a/rl_sandbox/config/config_slotted.yaml b/rl_sandbox/config/config_slotted.yaml new file mode 100644 index 0000000..7230c59 --- /dev/null +++ b/rl_sandbox/config/config_slotted.yaml @@ -0,0 +1,46 @@ +defaults: + - agent: dreamer_v2_slotted + - env: dm_cartpole + - training: dm + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +logger: + type: null + message: Cartpole 2 slots, double encoder, reduced warmup, 384 n_dim, correct prev_slots, 0.05 nats, 40 rssm dims, 16x16 stoch, 2x KL + log_grads: false + +training: + checkpoint_path: null + steps: 1e6 + prefill: 0 + pretrain: 1 + batch_size: 2 + val_logs_every: 1e3 + + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + #n_jobs: 3 + n_jobs: 1 + sweeper: + params: + #agent.kl_loss_scale: 64.0,128.0,256.0 diff --git a/rl_sandbox/metrics.py 
b/rl_sandbox/metrics.py index 07a20a9..39c259c 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -1,9 +1,35 @@ import numpy as np +import matplotlib.pyplot as plt +import torchvision as tv +from torch.nn import functional as F +import torch from rl_sandbox.utils.replay_buffer import Rollout +from rl_sandbox.utils.replay_buffer import (Action, Actions, Observation, + Observations, Rewards, + TerminationFlags, IsFirstFlags) -class MetricsEvaluator(): +class EpisodeMetricsEvaluator(): + def __init__(self, agent: 'DreamerV2', log_video: bool = False): + self.agent = agent + self.episode = 0 + self.log_video = log_video + + def on_step(self): + pass + + def on_episode(self, logger, rollout): + pass + + def on_val(self, logger, rollouts: list[Rollout]): + metrics = self.calculate_metrics(rollouts) + logger.log(metrics, self.episode, mode='val') + if self.log_video: + video = np.expand_dims(rollouts[0].observations.transpose(0, 3, 1, 2), 0) + logger.add_video('val/visualization', video, self.episode) + self.episode += 1 + def calculate_metrics(self, rollouts: list[Rollout]): return { 'episode_len': self._episode_duration(rollouts), @@ -15,3 +41,203 @@ def _episode_duration(self, rollouts: list[Rollout]): def _episode_return(self, rollouts: list[Rollout]): return np.mean(list(map(lambda x: sum(x.rewards), rollouts))) + +class DreamerMetricsEvaluator(): + def __init__(self, agent: 'DreamerV2'): + self.agent = agent + self.stored_steps = 0 + self.episode = 0 + + if agent.is_discrete: + pass + + def reset_ep(self): + self._latent_probs = torch.zeros((self.agent.world_model.latent_classes, self.agent.world_model.latent_dim), device=agent.device) + self._action_probs = torch.zeros((self.agent.actions_num), device=self.agent.device) + self.stored_steps = 0 + + def on_step(self): + self.stored_steps += 1 + + if self.agent.is_discrete: + self._action_probs += self._action_probs + self._latent_probs += self.agent._state[0].stoch_dist.probs.squeeze().mean(dim=0) + + def on_episode(self, logger, rollout): + latent_hist = (self.agent._latent_probs / self._stored_steps).detach().cpu().numpy() + latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) + + # if discrete action space + if self.agent.is_discrete: + action_hist = (self.agent._action_probs / self._stored_steps).detach().cpu().numpy() + fig = plt.Figure() + ax = fig.add_axes([0, 0, 1, 1]) + ax.bar(np.arange(self.agent.actions_num), action_hist) + logger.add_figure('val/action_probs', fig, self.episode) + else: + # log mean +- std + pass + logger.add_image('val/latent_probs', latent_hist, self.episode, dataformats='HW') + logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=1), self.episode, dataformats='HW') + + self.reset_ep() + self.episode += 1 + + def on_val(self, logger, rollouts: list[Rollout]): + self.viz_log(rollouts[0], logger, self.episode) + + def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): + obs = torch.from_numpy(obs.copy()).to(self.agent.device) + obs = self.agent.preprocess_obs(obs) + actions = self.agent.from_np(actions) + if self.agent.is_discrete: + actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() + video = [] + rews = [] + + state = None + means = np.array([0.485, 0.456, 0.406]) + stds = np.array([0.229, 0.224, 0.225]) + UnNormalize = tv.transforms.Normalize(list(-means/stds), + list(1/stds)) + for idx, (o, a) in enumerate(list(zip(obs, actions))): + if idx > update_num: + break + state = 
self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), state) + video_r = self.agent.world_model.image_predictor(state.combined).mode.cpu().detach().numpy() + rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) + if self.agent.world_model.encode_vit: + video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + else: + video_r = (video_r + 0.5) + video.append(video_r) + + rews = torch.Tensor(rews).to(obs.device) + + if update_num < len(obs): + states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) + rews = torch.cat([rews, rews_2[1:].squeeze()]) + video_r = self.agent.world_model.image_predictor(states.combined[1:]).mode.cpu().detach().numpy() + if self.agent.world_model.encode_vit: + video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + else: + video_r = (video_r + 0.5) + video.append(video_r) + + return np.concatenate(video), rews + + def viz_log(self, rollout, logger, epoch_num): + init_indeces = np.random.choice(len(rollout.states) - self.agent.imagination_horizon, 5) + + videos = np.concatenate([ + rollout.next_states[init_idx:init_idx + self.agent.imagination_horizon].transpose( + 0, 3, 1, 2) for init_idx in init_indeces + ], axis=3).astype(np.float32) / 255.0 + + real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] + + videos_r, imagined_rewards = zip(*[self._generate_video(obs_0.copy(), a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( + [rollout.next_states[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], + [rollout.actions[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces]) + ]) + videos_r = np.concatenate(videos_r, axis=3) + + videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r + 1)/2], axis=2), 0) + videos_comparison = (videos_comparison * 255.0).astype(np.uint8) + + logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) + + rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() + logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) + + logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) + + +class SlottedDreamerMetricsEvaluator(DreamerMetricsEvaluator): + def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): + obs = torch.from_numpy(obs.copy()).to(self.agent.device) + obs = self.agent.preprocess_obs(obs) + actions = self.agent.from_np(actions) + if self.agent.is_discrete: + actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() + video = [] + slots_video = [] + rews = [] + + state = None + prev_slots = None + means = np.array([0.485, 0.456, 0.406]) + stds = np.array([0.229, 0.224, 0.225]) + UnNormalize = tv.transforms.Normalize(list(-means/stds), + list(1/stds)) + for idx, (o, a) in enumerate(list(zip(obs, actions))): + if idx > update_num: + break + state, prev_slots = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), (state, prev_slots)) + # video_r = self.agent.world_model.image_predictor(state.combined_slots).mode.cpu().detach().numpy() + + decoded_imgs, masks = self.agent.world_model.image_predictor(state.combined_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) + # TODO: try the scaling of softmax as in attention + img_mask = F.softmax(masks, dim=1) + 
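+            # Weight each slot's RGB reconstruction by its softmaxed alpha mask; the per-slot
+            # images are then summed over the slot dimension to composite the final frame.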
decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1).cpu().detach().numpy() + + rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) + if self.agent.world_model.encode_vit: + video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + else: + video_r = (video_r + 0.5) + video.append(video_r) + slots_video.append(decoded_imgs.cpu().detach().numpy() + 0.5) + + rews = torch.Tensor(rews).to(obs.device) + + if update_num < len(obs): + states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) + rews = torch.cat([rews, rews_2[1:].squeeze()]) + + # video_r = self.agent.world_model.image_predictor(states.combined_slots[1:]).mode.cpu().detach().numpy() + decoded_imgs, masks = self.agent.world_model.image_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1).cpu().detach().numpy() + + if self.agent.world_model.encode_vit: + video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + else: + video_r = (video_r + 0.5) + video.append(video_r) + slots_video.append(decoded_imgs.cpu().detach().numpy() + 0.5) + + return np.concatenate(video), rews, np.concatenate(slots_video) + + def viz_log(self, rollout, logger, epoch_num): + init_indeces = np.random.choice(len(rollout.states) - self.agent.imagination_horizon, 5) + + videos = np.concatenate([ + rollout.next_states[init_idx:init_idx + self.agent.imagination_horizon].transpose( + 0, 3, 1, 2) for init_idx in init_indeces + ], axis=3).astype(np.float32) / 255.0 + + real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] + + videos_r, imagined_rewards, slots_video = zip(*[self._generate_video(obs_0.copy(), a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( + [rollout.next_states[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], + [rollout.actions[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces]) + ]) + videos_r = np.concatenate(videos_r, axis=3) + + slots_video = np.concatenate(list(slots_video)[:3], axis=3) + slots_video = slots_video.transpose((0, 2, 3, 1, 4)) + slots_video = np.expand_dims(slots_video.reshape(*slots_video.shape[:-2], -1), 0) + + videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r + 1)/2], axis=2), 0) + videos_comparison = (videos_comparison * 255.0).astype(np.uint8) + + logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) + logger.add_video('val/dreamed_slots', slots_video, epoch_num) + + rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() + logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) + + logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 04532ba..96db93e 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -14,7 +14,7 @@ from torch.profiler import ProfilerActivity, profile from tqdm import tqdm -from rl_sandbox.metrics import MetricsEvaluator +from rl_sandbox.metrics import EpisodeMetricsEvaluator from rl_sandbox.utils.env import Env from rl_sandbox.utils.logger import Logger from rl_sandbox.utils.replay_buffer import ReplayBuffer @@ -23,29 +23,17 @@ 
iter_rollout) -def val_logs(agent, val_cfg: DictConfig, env: Env, global_step: int, logger: Logger): +def val_logs(agent, val_cfg: DictConfig, metrics, env: Env, logger: Logger): with torch.no_grad(): - rollouts = collect_rollout_num(env, val_cfg.rollout_num, agent) - # TODO: make logs visualization in separate process - # Possibly make the data loader - metrics = MetricsEvaluator().calculate_metrics(rollouts) - logger.log(metrics, global_step, mode='val') + rollouts = collect_rollout_num(env, val_cfg.rollout_num, agent, collect_obs=True) - if val_cfg.visualize: - rollouts = collect_rollout_num(env, 1, agent, collect_obs=True) + for metric in metrics: + metric.on_val(logger, rollouts) - for rollout in rollouts: - video = np.expand_dims(rollout.observations.transpose(0, 3, 1, 2), 0) - logger.add_video('val/visualization', video, global_step) - # FIXME: Very bad from architecture point - with torch.no_grad(): - agent.viz_log(rollout, logger, global_step) - -@hydra.main(version_base="1.2", config_path='config', config_name='config') +@hydra.main(version_base="1.2", config_path='config', config_name='config_slotted') def main(cfg: DictConfig): lt.monkey_patch() - # print(OmegaConf.to_yaml(cfg)) torch.distributions.Distribution.set_default_validate_args(False) eval('setattr(torch.backends.cudnn, "benchmark", True)') # need to be pickable for multirun torch.backends.cuda.matmul.allow_tf32 = True @@ -64,6 +52,7 @@ def main(cfg: DictConfig): # TODO: Implement smarter techniques for exploration # (Plan2Explore, etc) + print(f'Start run: {cfg.logger.message}') logger = Logger(**cfg.logger) env: Env = hydra.utils.instantiate(cfg.env) @@ -93,6 +82,8 @@ def main(cfg: DictConfig): device_type=cfg.device_type, logger=logger) + metrics = [metric(agent) for metric in hydra.utils.instantiate(cfg.validation.metrics)] + prof = profile( activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], on_trace_ready=torch.profiler.tensorboard_trace_handler(logger.log_dir() + '/profiler'), @@ -108,11 +99,7 @@ def main(cfg: DictConfig): losses = agent.train(s, a, r, n, f, first) logger.log(losses, i, mode='pre_train') - # TODO: remove constants - # log_every_n = 25 - # if i % log_every_n == 0: - # st = int(cfg.training.pretrain) // log_every_n - val_logs(agent, cfg.validation, val_env, -1, logger) + val_logs(agent, cfg.validation, metrics, val_env, logger) if cfg.training.checkpoint_path is not None: prev_global_step = global_step = agent.load_ckpt(cfg.training.checkpoint_path) @@ -146,7 +133,7 @@ def main(cfg: DictConfig): ### Validation if (global_step % cfg.training.val_logs_every) <= (prev_global_step % cfg.training.val_logs_every): - val_logs(agent, cfg.validation, val_env, global_step, logger) + val_logs(agent, cfg.validation, metrics, val_env, logger) ### Checkpoint if (global_step % cfg.training.save_checkpoint_every) < ( From 6b1f4b515e4b8a9434e25d082cc8b8817d69ae90 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 3 Jun 2023 22:51:01 +0100 Subject: [PATCH 067/106] Get slotted crafter working with DINO features --- rl_sandbox/agents/dreamer/world_model.py | 2 +- .../agents/dreamer/world_model_slots.py | 31 +- .../agents/dreamer/world_model_slots_dino.py | 306 ++++++++++++++++++ rl_sandbox/agents/dreamer_v2.py | 3 + rl_sandbox/config/agent/dreamer_v2.yaml | 8 +- .../config/agent/dreamer_v2_crafter.yaml | 88 +++-- .../agent/dreamer_v2_crafter_slotted.yaml | 73 +++++ .../config/agent/dreamer_v2_slotted.yaml | 10 +- .../config/agent/dreamer_v2_slotted_dino.yaml | 73 +++++ rl_sandbox/config/config.yaml | 25 +- 
rl_sandbox/config/config_dino.yaml | 14 +- rl_sandbox/config/config_dino_1.yaml | 43 +++ rl_sandbox/config/config_dino_2.yaml | 44 +++ rl_sandbox/config/config_slotted.yaml | 22 +- rl_sandbox/config/config_slotted_debug.yaml | 44 +++ rl_sandbox/metrics.py | 136 +++++++- rl_sandbox/train.py | 8 +- rl_sandbox/vision/slot_attention.py | 4 +- 18 files changed, 845 insertions(+), 89 deletions(-) create mode 100644 rl_sandbox/agents/dreamer/world_model_slots_dino.py create mode 100644 rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml create mode 100644 rl_sandbox/config/agent/dreamer_v2_slotted_dino.yaml create mode 100644 rl_sandbox/config/config_dino_1.yaml create mode 100644 rl_sandbox/config/config_dino_2.yaml create mode 100644 rl_sandbox/config/config_slotted_debug.yaml diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index 3ad2a4b..4347b69 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -8,7 +8,7 @@ from rl_sandbox.agents.dreamer import Dist, Normalizer from rl_sandbox.agents.dreamer.rssm import RSSM, State -from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, ViTDecoder +from rl_sandbox.agents.dreamer.vision import Decoder, Encoder from rl_sandbox.utils.dists import DistLayer from rl_sandbox.utils.fc_nn import fc_nn_generator from rl_sandbox.vision.dino import ViTFeat diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index c91ba82..f8535ea 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -20,8 +20,9 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, - decode_vit: bool, vit_l2_ratio: float, slots_num: int): + decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True): super().__init__() + self.use_prev_slots = use_prev_slots self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) self.kl_beta = kl_loss_scale @@ -80,7 +81,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, double_conv=True, flatten_output=False) - self.slot_attention = SlotAttention(slots_num, self.n_dim, 5) + self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num) self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) @@ -89,9 +90,12 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, nn.Linear(self.n_dim, self.n_dim)) if decode_vit: - self.dino_predictor = ViTDecoder( - self.state_size, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm) + self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + channel_step=192, + kernel_sizes=[3, 4], + output_channels=self.vit_feat_dim+1, + return_dist=True) # self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, # 64*self.dino_vit.feat_dim, # hidden_size=2048, @@ -102,7 +106,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.image_predictor = Decoder( rssm_dim + latent_dim * latent_classes, norm_layer=nn.Identity if layer_norm else nn.GroupNorm, - output_channels=4, + output_channels=3+1, 
return_dist=False) self.reward_predictor = fc_nn_generator(self.state_size, @@ -156,7 +160,10 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t if state is None or state[0] is None: state, prev_slots = self.get_initial_state() else: - state, prev_slots = state + if self.use_prev_slots: + state, prev_slots = state + else: + state, prev_slots = state[0], None embed = self.encoder(obs.unsqueeze(0)) embed_with_pos_enc = self.positional_augmenter_inp(embed) @@ -228,7 +235,11 @@ def KL(dist1, dist2): a_t = a_t * (1 - first_t) slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) - # prev_slots = None + # FIXME: prev_slots was not used properly, need to rerun test + if self.use_prev_slots: + prev_slots = slots_t + else: + prev_slots = None prior, posterior, diff = self.recurrent_model.forward( prev_state, slots_t.unsqueeze(0), a_t) @@ -260,8 +271,8 @@ def KL(dist1, dist2): # x_r = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) # img_rec = -x_r.log_prob(obs).float().mean() # else: - # img_rec = 0 - # x_r_detached = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1).detach()) + # img_rec = 0 + # x_r_detached = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1).detach()) # losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() # d_pred = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) # losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 14, 14)).float().mean()/4 + diff --git a/rl_sandbox/agents/dreamer/world_model_slots_dino.py b/rl_sandbox/agents/dreamer/world_model_slots_dino.py new file mode 100644 index 0000000..1476440 --- /dev/null +++ b/rl_sandbox/agents/dreamer/world_model_slots_dino.py @@ -0,0 +1,306 @@ +import typing as t + +import torch +import torch.distributions as td +import torchvision as tv +from torch import nn +from torch.nn import functional as F + +from rl_sandbox.agents.dreamer import Dist, Normalizer +from rl_sandbox.agents.dreamer.rssm_slots import RSSM, State +from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, ViTDecoder +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.utils.fc_nn import fc_nn_generator +from rl_sandbox.vision.dino import ViTFeat +from rl_sandbox.vision.slot_attention import PositionalEmbedding, SlotAttention + + +class WorldModel(nn.Module): + + def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, + actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, + discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, + decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True): + super().__init__() + self.use_prev_slots = use_prev_slots + self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.kl_beta = kl_loss_scale + + self.rssm_dim = rssm_dim + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.slots_num = slots_num + self.state_size = slots_num * (rssm_dim + latent_dim * latent_classes) + + self.cluster_size = batch_cluster_size + self.actions_num = actions_num + # kl loss balancing (prior/posterior) + self.alpha = kl_loss_balancing + self.predict_discount = predict_discount + self.encode_vit = encode_vit + self.decode_vit = decode_vit + self.vit_l2_ratio = vit_l2_ratio + + self.n_dim = 384 + + self.recurrent_model = RSSM( + latent_dim, + rssm_dim, + actions_num, + 
latent_classes, + discrete_rssm, + norm_layer=nn.Identity if layer_norm else nn.LayerNorm, + embed_size=self.n_dim) + if encode_vit or decode_vit: + # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) + # self.dino_vit = ViTFeat( + # "/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + # feat_dim=384, + # vit_arch='small', + # patch_size=16) + self.vit_feat_dim = self.dino_vit.feat_dim + self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches + self.dino_vit.requires_grad_(False) + + if encode_vit: + self.encoder = nn.Sequential( + self.dino_vit, + nn.Flatten(), + # fc_nn_generator(64*self.dino_vit.feat_dim, + # 64*384, + # hidden_size=400, + # num_layers=5, + # intermediate_activation=nn.ELU, + # layer_norm=layer_norm) + ) + else: + self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + kernel_sizes=[4, 4, 4], + channel_step=96, + double_conv=True, + flatten_output=False) + + self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num) + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) + # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) + + self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), + nn.ReLU(inplace=True), + nn.Linear(self.n_dim, self.n_dim)) + + if decode_vit: + self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + channel_step=192, + # kernel_sizes=[5, 5, 4], # for size 224x224 + kernel_sizes=[3, 4], + output_channels=self.vit_feat_dim+1, + return_dist=False) + # self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, + # 64*self.dino_vit.feat_dim, + # hidden_size=2048, + # num_layers=5, + # intermediate_activation=nn.ELU, + # layer_norm=layer_norm, + # final_activation=DistLayer('mse')) + self.image_predictor = Decoder( + rssm_dim + latent_dim * latent_classes, + norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + output_channels=3+1, + return_dist=False) + + self.reward_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.discount_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('binary')) + self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): + device = next(self.parameters()).device + # Tuple of State-Space state and prev slots + return State( + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.rssm_dim, + device=device), + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.latent_classes, + self.latent_dim, + device=device), + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.latent_classes * self.latent_dim, + device=device)), None + + def predict_next(self, prev_state: State, action): + prior, _ = self.recurrent_model.predict_next(prev_state, action) + + reward = self.reward_predictor(prior.combined).mode + if self.predict_discount: + discount_factors = self.discount_predictor(prior.combined).sample() + else: + discount_factors = torch.ones_like(reward) + return 
prior, reward, discount_factors + + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, torch.Tensor]]) -> t.Tuple[State, torch.Tensor]: + if state is None or state[0] is None: + state, prev_slots = self.get_initial_state() + else: + if self.use_prev_slots: + state, prev_slots = state + else: + state, prev_slots = state[0], None + embed = self.encoder(obs.unsqueeze(0)) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + + pre_slot_features_t = self.slot_mlp( + embed_with_pos_enc.permute(0, 2, 3, 1).reshape(1, -1, self.n_dim)) + + slots_t = self.slot_attention(pre_slot_features_t, prev_slots) + + _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), + action) + return posterior, slots_t + + def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + discount: torch.Tensor, first: torch.Tensor): + b, _, h, w = obs.shape # s <- BxHxWx3 + + embed = self.encoder(obs) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + # embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) + + pre_slot_features = self.slot_mlp( + embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) + pre_slot_features_c = pre_slot_features.reshape(b // self.cluster_size, + self.cluster_size, -1, self.n_dim) + + a_c = a.reshape(-1, self.cluster_size, self.actions_num) + r_c = r.reshape(-1, self.cluster_size, 1) + d_c = discount.reshape(-1, self.cluster_size, 1) + first_c = first.reshape(-1, self.cluster_size, 1) + + losses = {} + metrics = {} + + def KL(dist1, dist2): + KL_ = torch.distributions.kl_divergence + kl_lhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), + td.OneHotCategoricalStraightThrough(logits=dist1)).mean() + kl_rhs = KL_( + td.OneHotCategoricalStraightThrough(logits=dist2), + td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() + kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) + kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) + return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + + priors = [] + posteriors = [] + + if self.decode_vit: + inp = obs + if not self.encode_vit: + # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + # (0.229, 0.224, 0.225)), + # tv.transforms.Resize(224, antialias=True)]) + ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + inp = ToTensor(obs + 0.5) + d_features = self.dino_vit(inp) + + prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) + for t in range(self.cluster_size): + # s_t <- 1xB^xHxWx3 + pre_slot_feature_t, a_t, first_t = pre_slot_features_c[:, + t], a_c[:, t].unsqueeze( + 0 + ), first_c[:, + t].unsqueeze( + 0) + a_t = a_t * (1 - first_t) + + slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) + # FIXME: prev_slots was not used properly, need to rerun test + if self.use_prev_slots: + prev_slots = slots_t + else: + prev_slots = None + + prior, posterior, diff = self.recurrent_model.forward( + prev_state, slots_t.unsqueeze(0), a_t) + prev_state = posterior + + priors.append(prior) + posteriors.append(posterior) + + # losses['loss_determ_recons'] += diff + + posterior = State.stack(posteriors) + prior = State.stack(priors) + + r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) + f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) + + losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) + + if not self.decode_vit: + decoded_imgs, masks = 
self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + + losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() + else: + if self.vit_l2_ratio != 1.0: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean() + else: + img_rec = 0 + decoded_imgs_detached, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs_detached = decoded_imgs_detached * img_mask + x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) + losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() + + decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, 8, 8).split([self.vit_feat_dim, 1], dim=2) + feat_mask = F.softmax(masks, dim=1) + decoded_feats = decoded_feats * feat_mask + d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) + losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 8, 8)).float().mean() + + (1-self.vit_l2_ratio) * img_rec) + + prior_logits = prior.stoch_logits + posterior_logits = posterior.stoch_logits + losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() + losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() + losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) + + metrics['reward_mean'] = r.mean() + metrics['reward_std'] = r.std() + metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() + metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() + metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() + + losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + + losses['loss_kl_reg'] + losses['loss_discount_pred']) + + return losses, posterior, metrics + diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index c099281..a2212c1 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -114,6 +114,9 @@ def get_action(self, obs: Observation) -> Action: actor_dist = self.actor.get_action(self._state) self._last_action = actor_dist.sample() + if self.is_discrete: + self._action_probs += actor_dist.probs.squeeze() + if self.is_discrete: return self._last_action.squeeze().detach().cpu().numpy().argmax() else: diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index fdc0129..ee3cfc6 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -11,12 +11,12 @@ world_model: latent_dim: 32 latent_classes: 32 rssm_dim: 200 - kl_loss_scale: 1.0 + kl_loss_scale: 1e1 kl_loss_balancing: 0.8 - kl_free_nats: 0.05 + kl_free_nats: 0.00 discrete_rssm: false - decode_vit: true - vit_l2_ratio: 1.0 + decode_vit: false + vit_l2_ratio: 0.8 encode_vit: false predict_discount: false layer_norm: ${..layer_norm} diff --git 
a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index 0f5423c..ef299bb 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -1,33 +1,67 @@ _target_: rl_sandbox.agents.DreamerV2 -layer_norm: true -# World model parameters -batch_cluster_size: 50 -latent_dim: 32 -latent_classes: 32 -rssm_dim: 1024 -kl_loss_scale: 1.0 -kl_loss_balancing: 0.8 -kl_loss_free_nats: 0.0 -world_model_lr: 2e-4 -world_model_predict_discount: true -# ActorCritic parameters -discount_factor: 0.999 imagination_horizon: 15 +batch_cluster_size: 50 +layer_norm: true + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model.WorldModel + _partial_: true + batch_cluster_size: ${..batch_cluster_size} + latent_dim: 32 + latent_classes: 32 + rssm_dim: 1024 + kl_loss_scale: 1.0 + kl_loss_balancing: 0.8 + kl_free_nats: 0.00 + discrete_rssm: false + decode_vit: true + vit_l2_ratio: 1.0 + encode_vit: false + predict_discount: true + layer_norm: ${..layer_norm} + +actor: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor + _partial_: true + # mixing of reinforce and maximizing value func + # for dm_control it is zero in Dreamer (Atari 1) + reinforce_fraction: null + entropy_scale: 3e-3 + layer_norm: ${..layer_norm} + +critic: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic + _partial_: true + discount_factor: 0.999 + update_interval: 100 + # [0-1], 1 means hard update + soft_update_fraction: 1 + # Lambda parameter for trainin deeper multi-step prediction + value_target_lambda: 0.95 + layer_norm: ${..layer_norm} -actor_lr: 2e-4 -# automatically chooses depending on discrete/continuous env -actor_reinforce_fraction: null -actor_entropy_scale: 3e-3 +wm_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr_scheduler: null + lr: 2e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 -critic_lr: 2e-4 -# Lambda parameter for trainin deeper multi-step prediction -critic_value_target_lambda: 0.95 -critic_update_interval: 100 -# [0-1], 1 means hard update -critic_soft_update_fraction: 1 +actor_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 2e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 -discrete_rssm: false -decode_vit: true -vit_l2_ratio: 0.5 -encode_vit: false +critic_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 2e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml new file mode 100644 index 0000000..d2fd6ca --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml @@ -0,0 +1,73 @@ +_target_: rl_sandbox.agents.DreamerV2 + +imagination_horizon: 15 +batch_cluster_size: 50 +layer_norm: true + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model_slots.WorldModel + _partial_: true + batch_cluster_size: ${..batch_cluster_size} + latent_dim: 22 + latent_classes: ${.latent_dim} + rssm_dim: 256 + slots_num: 6 + slots_iter_num: 2 + kl_loss_scale: 1e1 + kl_loss_balancing: 0.8 + kl_free_nats: 0.00 + discrete_rssm: false + decode_vit: false + use_prev_slots: false + vit_l2_ratio: 1.0 + encode_vit: false + predict_discount: true + layer_norm: ${..layer_norm} + +actor: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor + _partial_: true + # mixing of reinforce and maximizing value func + # for dm_control it is zero in Dreamer (Atari 1) + reinforce_fraction: null + 
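+  # null: fraction is chosen automatically depending on discrete/continuous env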
entropy_scale: 3e-3 + layer_norm: ${..layer_norm} + +critic: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic + _partial_: true + discount_factor: 0.999 + update_interval: 100 + # [0-1], 1 means hard update + soft_update_fraction: 1 + # Lambda parameter for trainin deeper multi-step prediction + value_target_lambda: 0.95 + layer_norm: ${..layer_norm} + +wm_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 + lr: 2e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +actor_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 2e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +critic_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 2e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted.yaml index f395d45..33241af 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted.yaml @@ -8,16 +8,18 @@ world_model: _target_: rl_sandbox.agents.dreamer.world_model_slots.WorldModel _partial_: true batch_cluster_size: ${..batch_cluster_size} - latent_dim: 16 - latent_classes: 16 - rssm_dim: 40 - slots_num: 2 + latent_dim: 22 + latent_classes: 22 + rssm_dim: 80 + slots_num: 4 + slots_iter_num: 5 kl_loss_scale: 2.0 kl_loss_balancing: 0.8 kl_free_nats: 0.05 discrete_rssm: false decode_vit: false vit_l2_ratio: 1.0 + use_prev_slots: true encode_vit: false predict_discount: false layer_norm: ${..layer_norm} diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_dino.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_dino.yaml new file mode 100644 index 0000000..1d53297 --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_dino.yaml @@ -0,0 +1,73 @@ +_target_: rl_sandbox.agents.DreamerV2 + +imagination_horizon: 15 +batch_cluster_size: 50 +layer_norm: true + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model_slots_dino.WorldModel + _partial_: true + batch_cluster_size: ${..batch_cluster_size} + latent_dim: 32 + latent_classes: ${.latent_dim} + rssm_dim: 256 + slots_num: 6 + slots_iter_num: 2 + kl_loss_scale: 1e2 + kl_loss_balancing: 0.8 + kl_free_nats: 0.00 + discrete_rssm: false + decode_vit: true + use_prev_slots: false + vit_l2_ratio: 0.8 + encode_vit: false + predict_discount: true + layer_norm: ${..layer_norm} + +actor: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor + _partial_: true + # mixing of reinforce and maximizing value func + # for dm_control it is zero in Dreamer (Atari 1) + reinforce_fraction: null + entropy_scale: 3e-3 + layer_norm: ${..layer_norm} + +critic: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic + _partial_: true + discount_factor: 0.999 + update_interval: 100 + # [0-1], 1 means hard update + soft_update_fraction: 1 + # Lambda parameter for trainin deeper multi-step prediction + value_target_lambda: 0.95 + layer_norm: ${..layer_norm} + +wm_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 + lr: 2e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +actor_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 2e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +critic_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true 
+ lr: 2e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index eee688a..fd15765 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,7 +1,7 @@ defaults: - - agent: dreamer_v2 - - env: dm_cartpole - - training: dm + - agent: dreamer_v2_slotted_debug + - env: crafter + - training: crafter - _self_ - override hydra/launcher: joblib @@ -9,15 +9,14 @@ seed: 42 device_type: cuda logger: - type: null - message: Dreamer with DINO features + type: tensorboard + message: Crafter 6 DINO slots, 32 latents, 256 rssm log_grads: false training: checkpoint_path: null steps: 1e6 - val_logs_every: 1e3 - + val_logs_every: 2e4 validation: rollout_num: 5 @@ -26,18 +25,18 @@ validation: - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator log_video: True _partial_: true - - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator + - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator _partial_: true debug: profiler: false hydra: - #mode: MULTIRUN - mode: RUN + mode: MULTIRUN + #mode: RUN launcher: - #n_jobs: 3 - n_jobs: 1 + n_jobs: 8 sweeper: params: - #agent.kl_loss_scale: 64.0,128.0,256.0 + agent.world_model.kl_loss_scale: 1e1,1e2,1e3,1e4 + agent.world_model.vit_l2_ratio: 0.1,0.9 diff --git a/rl_sandbox/config/config_dino.yaml b/rl_sandbox/config/config_dino.yaml index eee688a..569fba2 100644 --- a/rl_sandbox/config/config_dino.yaml +++ b/rl_sandbox/config/config_dino.yaml @@ -1,6 +1,6 @@ defaults: - agent: dreamer_v2 - - env: dm_cartpole + - env: dm_quadruped - training: dm - _self_ - override hydra/launcher: joblib @@ -9,14 +9,14 @@ seed: 42 device_type: cuda logger: - type: null - message: Dreamer with DINO features + type: tensorboard + message: Quadruped with DINO features log_grads: false training: checkpoint_path: null steps: 1e6 - val_logs_every: 1e3 + val_logs_every: 1e4 validation: @@ -36,8 +36,8 @@ hydra: #mode: MULTIRUN mode: RUN launcher: - #n_jobs: 3 + #n_jobs: 8 n_jobs: 1 sweeper: - params: - #agent.kl_loss_scale: 64.0,128.0,256.0 + #params: + # agent.world_model.kl_loss_scale: 1e-4,1e-3,1e-2,0.1,1.0,1e2,1e3,1e4 diff --git a/rl_sandbox/config/config_dino_1.yaml b/rl_sandbox/config/config_dino_1.yaml new file mode 100644 index 0000000..3c082c1 --- /dev/null +++ b/rl_sandbox/config/config_dino_1.yaml @@ -0,0 +1,43 @@ +defaults: + - agent: dreamer_v2 + - env: dm_cheetah + - training: dm + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +logger: + type: tensorboard + message: Cheetah with DINO features, 0.75 ratio + log_grads: false + +training: + checkpoint_path: null + steps: 2e6 + val_logs_every: 2e4 + + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + #n_jobs: 8 + n_jobs: 1 + #sweeper: + #params: + # agent.world_model.kl_loss_scale: 1e-4,1e-3,1e-2,0.1,1.0,1e2,1e3,1e4 diff --git a/rl_sandbox/config/config_dino_2.yaml b/rl_sandbox/config/config_dino_2.yaml new file mode 100644 index 0000000..fd42a77 --- /dev/null +++ b/rl_sandbox/config/config_dino_2.yaml @@ -0,0 +1,44 @@ +defaults: + - agent: dreamer_v2 + - env: dm_acrobot + - training: dm + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +logger: + type: tensorboard + message: Acrobot default + log_grads: false + 
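+# Same layout as config_dino_1.yaml; only the environment (dm_acrobot) and the logger message differ.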
+training: + checkpoint_path: null + steps: 2e6 + val_logs_every: 2e4 + + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + #n_jobs: 8 + n_jobs: 1 + #sweeper: + #params: + # agent.world_model.kl_loss_scale: 1e-4,1e-3,1e-2,0.1,1.0,1e2,1e3,1e4 + diff --git a/rl_sandbox/config/config_slotted.yaml b/rl_sandbox/config/config_slotted.yaml index 7230c59..f28cac8 100644 --- a/rl_sandbox/config/config_slotted.yaml +++ b/rl_sandbox/config/config_slotted.yaml @@ -1,5 +1,5 @@ defaults: - - agent: dreamer_v2_slotted + - agent: dreamer_v2_slotted_debug - env: dm_cartpole - training: dm - _self_ @@ -9,17 +9,14 @@ seed: 42 device_type: cuda logger: - type: null - message: Cartpole 2 slots, double encoder, reduced warmup, 384 n_dim, correct prev_slots, 0.05 nats, 40 rssm dims, 16x16 stoch, 2x KL + type: tensorboard + message: Cartpole 4 slots, 384 n_dim, 80 rssm dims, 22x22 stoch log_grads: false training: checkpoint_path: null steps: 1e6 - prefill: 0 - pretrain: 1 - batch_size: 2 - val_logs_every: 1e3 + val_logs_every: 5e4 validation: @@ -36,11 +33,12 @@ debug: profiler: false hydra: - #mode: MULTIRUN - mode: RUN + mode: MULTIRUN + #mode: RUN launcher: - #n_jobs: 3 - n_jobs: 1 + n_jobs: 8 + #n_jobs: 1 sweeper: params: - #agent.kl_loss_scale: 64.0,128.0,256.0 + agent.world_model.kl_loss_scale: 0.1,1e2,1e3,1e4 + agent.world_model.kl_free_nats: 0,1e-2 diff --git a/rl_sandbox/config/config_slotted_debug.yaml b/rl_sandbox/config/config_slotted_debug.yaml new file mode 100644 index 0000000..4eaef5d --- /dev/null +++ b/rl_sandbox/config/config_slotted_debug.yaml @@ -0,0 +1,44 @@ +defaults: + - agent: dreamer_v2_slotted_debug + - env: crafter + - training: crafter + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +logger: + type: tensorboard + message: Crafter 5 slots, 1e2 kl loss, 0.999 vit + log_grads: false + +training: + checkpoint_path: null + steps: 1e6 + val_logs_every: 2e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + #n_jobs: 4 + n_jobs: 1 + #sweeper: + # params: + # agent.world_model.kl_loss_scale: 1e3 + #agent.world_model.latent_dim: 22,32 + #agent.world_model.rssm_dim: 128,256 diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 39c259c..987190c 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -16,10 +16,10 @@ def __init__(self, agent: 'DreamerV2', log_video: bool = False): self.episode = 0 self.log_video = log_video - def on_step(self): + def on_step(self, logger): pass - def on_episode(self, logger, rollout): + def on_episode(self, logger): pass def on_val(self, logger, rollouts: list[Rollout]): @@ -51,25 +51,27 @@ def __init__(self, agent: 'DreamerV2'): if agent.is_discrete: pass + self.reset_ep() + def reset_ep(self): - self._latent_probs = torch.zeros((self.agent.world_model.latent_classes, self.agent.world_model.latent_dim), device=agent.device) + self._latent_probs = torch.zeros((self.agent.world_model.latent_classes, self.agent.world_model.latent_dim), device=self.agent.device) self._action_probs = 
torch.zeros((self.agent.actions_num), device=self.agent.device) self.stored_steps = 0 - def on_step(self): + def on_step(self, logger): self.stored_steps += 1 if self.agent.is_discrete: self._action_probs += self._action_probs - self._latent_probs += self.agent._state[0].stoch_dist.probs.squeeze().mean(dim=0) + self._latent_probs += self.agent._state.stoch_dist.probs.squeeze().mean(dim=0) - def on_episode(self, logger, rollout): - latent_hist = (self.agent._latent_probs / self._stored_steps).detach().cpu().numpy() + def on_episode(self, logger): + latent_hist = (self._latent_probs / self.stored_steps).detach().cpu().numpy() latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) # if discrete action space if self.agent.is_discrete: - action_hist = (self.agent._action_probs / self._stored_steps).detach().cpu().numpy() + action_hist = (self.agent._action_probs / self.stored_steps).detach().cpu().numpy() fig = plt.Figure() ax = fig.add_axes([0, 0, 1, 1]) ax.bar(np.arange(self.agent.actions_num), action_hist) @@ -154,6 +156,13 @@ def viz_log(self, rollout, logger, epoch_num): class SlottedDreamerMetricsEvaluator(DreamerMetricsEvaluator): + def on_step(self, logger): + self.stored_steps += 1 + + if self.agent.is_discrete: + self._action_probs += self._action_probs + self._latent_probs += self.agent._state[0].stoch_dist.probs.squeeze().mean(dim=0) + def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): obs = torch.from_numpy(obs.copy()).to(self.agent.device) obs = self.agent.preprocess_obs(obs) @@ -241,3 +250,114 @@ def viz_log(self, rollout, logger, epoch_num): logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) + +class SlottedDinoDreamerMetricsEvaluator(SlottedDreamerMetricsEvaluator): + def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): + obs = torch.from_numpy(obs.copy()).to(self.agent.device) + obs = self.agent.preprocess_obs(obs) + actions = self.agent.from_np(actions) + if self.agent.is_discrete: + actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() + video = [] + slots_video = [] + vit_slots_video = [] + rews = [] + + state = None + prev_slots = None + means = np.array([0.485, 0.456, 0.406]) + stds = np.array([0.229, 0.224, 0.225]) + UnNormalize = tv.transforms.Normalize(list(-means/stds), + list(1/stds)) + for idx, (o, a) in enumerate(list(zip(obs, actions))): + if idx > update_num: + break + state, prev_slots = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), (state, prev_slots)) + # video_r = self.agent.world_model.image_predictor(state.combined_slots).mode.cpu().detach().numpy() + + decoded_imgs, masks = self.agent.world_model.image_predictor(state.combined_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) + # TODO: try the scaling of softmax as in attention + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1).cpu().detach().numpy() + + _, vit_masks = self.agent.world_model.dino_predictor(state.combined_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, 8, 8).split([self.agent.world_model.vit_feat_dim, 1], dim=2) + vit_mask = F.softmax(vit_masks, dim=1) + upscale = tv.transforms.Resize(64, antialias=True) + upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) + per_slot_vit = 
(upscaled_mask.unsqueeze(1) * o.to(self.agent.device).unsqueeze(0)).unsqueeze(0) + + rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) + if self.agent.world_model.encode_vit: + video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + else: + video_r = (video_r + 0.5) + video.append(video_r) + slots_video.append(decoded_imgs.cpu().detach().numpy() + 0.5) + vit_slots_video.append(per_slot_vit.cpu().detach().numpy()/upscaled_mask.max().cpu().detach().numpy() + 0.5) + + rews = torch.Tensor(rews).to(obs.device) + + if update_num < len(obs): + states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) + rews = torch.cat([rews, rews_2[1:].squeeze()]) + + # video_r = self.agent.world_model.image_predictor(states.combined_slots[1:]).mode.cpu().detach().numpy() + decoded_imgs, masks = self.agent.world_model.image_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1).cpu().detach().numpy() + + _, vit_masks = self.agent.world_model.dino_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, 8, 8).split([self.agent.world_model.vit_feat_dim, 1], dim=2) + vit_mask = F.softmax(vit_masks, dim=1) + upscale = tv.transforms.Resize(64, antialias=True) + upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) + per_slot_vit = (upscaled_mask.unsqueeze(2) * obs[update_num+1:].to(self.agent.device).unsqueeze(1)) + # per_slot_vit = (upscaled_mask.unsqueeze(1) * o.to(self.agent.device).unsqueeze(0)).unsqueeze(0) + + if self.agent.world_model.encode_vit: + video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + else: + video_r = (video_r + 0.5) + video.append(video_r) + slots_video.append(decoded_imgs.cpu().detach().numpy() + 0.5) + vit_slots_video.append(per_slot_vit.cpu().detach().numpy()/np.expand_dims(upscaled_mask.cpu().detach().numpy().max(axis=(1,2,3)), axis=(1,2,3,4)) + 0.5) + + return np.concatenate(video), rews, np.concatenate(slots_video), np.concatenate(vit_slots_video) + + def viz_log(self, rollout, logger, epoch_num): + init_indeces = np.random.choice(len(rollout.states) - self.agent.imagination_horizon, 5) + + videos = np.concatenate([ + rollout.next_states[init_idx:init_idx + self.agent.imagination_horizon].transpose( + 0, 3, 1, 2) for init_idx in init_indeces + ], axis=3).astype(np.float32) / 255.0 + + real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] + + videos_r, imagined_rewards, slots_video, vit_masks_video = zip(*[self._generate_video(obs_0.copy(), a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( + [rollout.next_states[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], + [rollout.actions[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces]) + ]) + videos_r = np.concatenate(videos_r, axis=3) + + slots_video = np.concatenate(list(slots_video)[:3], axis=3) + slots_video = slots_video.transpose((0, 2, 3, 1, 4)) + slots_video = np.expand_dims(slots_video.reshape(*slots_video.shape[:-2], -1), 0) + + videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r + 1)/2], axis=2), 0) + videos_comparison = (videos_comparison * 255.0).astype(np.uint8) + + 
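+        # videos_comparison stacks ground truth, reconstruction and their rescaled
+        # absolute difference along the height axis, then converts to uint8 for the logger.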
vit_masks_video = np.concatenate(list(vit_masks_video)[:3], axis=3) + vit_masks_video = vit_masks_video.transpose((0, 2, 3, 1, 4)) + vit_masks_video = np.expand_dims(slots_video.reshape(*vit_masks_video.shape[:-2], -1), 0) + + logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) + logger.add_video('val/dreamed_slots', slots_video, epoch_num) + logger.add_video('val/dreamed_vit_masks', vit_masks_video, epoch_num) + + # FIXME: rewrite sum(...) as (...).sum() + rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() + logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) + + logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 96db93e..5d3bcad 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -31,7 +31,7 @@ def val_logs(agent, val_cfg: DictConfig, metrics, env: Env, logger: Logger): metric.on_val(logger, rollouts) -@hydra.main(version_base="1.2", config_path='config', config_name='config_slotted') +@hydra.main(version_base="1.2", config_path='config', config_name='config') def main(cfg: DictConfig): lt.monkey_patch() torch.distributions.Distribution.set_default_validate_args(False) @@ -126,9 +126,15 @@ def main(cfg: DictConfig): if global_step % 100 == 0: logger.log(losses, global_step, mode='train') + for metric in metrics: + metric.on_step(logger) + global_step += cfg.env.repeat_action_num pbar.update(cfg.env.repeat_action_num) + for metric in metrics: + metric.on_episode(logger) + # FIXME: find more appealing solution ### Validation if (global_step % cfg.training.val_logs_every) <= (prev_global_step % diff --git a/rl_sandbox/vision/slot_attention.py b/rl_sandbox/vision/slot_attention.py index f0fdd81..ebde600 100644 --- a/rl_sandbox/vision/slot_attention.py +++ b/rl_sandbox/vision/slot_attention.py @@ -20,8 +20,8 @@ def __init__(self, num_slots: int, n_dim: int, n_iter: int): self.scale = self.n_dim**(-1/2) self.epsilon = 1e-8 - self.slots_mu = nn.Parameter(torch.randn(1, 1, self.n_dim)) - self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.n_dim)) + self.slots_mu = nn.Parameter(torch.randn(1, num_slots, self.n_dim)) + self.slots_logsigma = nn.Parameter(torch.zeros(1, num_slots, self.n_dim)) nn.init.xavier_uniform_(self.slots_logsigma) self.slots_proj = nn.Linear(n_dim, n_dim) From a4d0dad640c619de1e58d574dd0e916dc16d05b3 Mon Sep 17 00:00:00 2001 From: Midren Date: Thu, 8 Jun 2023 17:04:26 +0100 Subject: [PATCH 068/106] Added per-slot statistics and all slot dynamics learning --- rl_sandbox/agents/dreamer/common.py | 36 +++ rl_sandbox/agents/dreamer/rssm.py | 36 +-- rl_sandbox/agents/dreamer/rssm_slots.py | 104 +----- .../agents/dreamer/rssm_slots_attention.py | 191 +++++++++++ .../agents/dreamer/rssm_slots_combined.py | 167 ++++++++++ .../agents/dreamer/world_model_slots.py | 67 ++-- .../dreamer/world_model_slots_attention.py | 306 ++++++++++++++++++ ..._dino.py => world_model_slots_combined.py} | 4 +- rl_sandbox/agents/dreamer_v2.py | 3 +- rl_sandbox/metrics.py | 14 + 10 files changed, 760 insertions(+), 168 deletions(-) create mode 100644 rl_sandbox/agents/dreamer/rssm_slots_attention.py create mode 100644 rl_sandbox/agents/dreamer/rssm_slots_combined.py create mode 100644 rl_sandbox/agents/dreamer/world_model_slots_attention.py rename rl_sandbox/agents/dreamer/{world_model_slots_dino.py => world_model_slots_combined.py} (99%) diff --git a/rl_sandbox/agents/dreamer/common.py 
b/rl_sandbox/agents/dreamer/common.py index e3efcb2..84aab92 100644 --- a/rl_sandbox/agents/dreamer/common.py +++ b/rl_sandbox/agents/dreamer/common.py @@ -35,3 +35,39 @@ def forward(self, x): def update(self, x): self.mag = self.momentum * self.mag + (1 - self.momentum) * (x.abs().mean()).detach() + + +class GRUCell(nn.Module): + + def __init__(self, input_size, hidden_size, norm=False, update_bias=-1, **kwargs): + super().__init__() + self._size = hidden_size + self._act = torch.tanh + self._norm = norm + self._update_bias = update_bias + self._layer = nn.Linear(input_size + hidden_size, + 3 * hidden_size, + bias=norm is not None, + **kwargs) + if norm: + self._norm = nn.LayerNorm(3 * hidden_size) + + @property + def state_size(self): + return self._size + + def forward(self, x, h): + state = h + parts = self._layer(torch.concat([x, state], -1)) + if self._norm: + dtype = parts.dtype + parts = self._norm(parts.float()) + parts = parts.to(dtype=dtype) + reset, cand, update = parts.chunk(3, dim=-1) + reset = torch.sigmoid(reset) + cand = self._act(reset * cand) + update = torch.sigmoid(update + self._update_bias) + output = update * cand + (1 - update) * state + return output, output + + diff --git a/rl_sandbox/agents/dreamer/rssm.py b/rl_sandbox/agents/dreamer/rssm.py index 274d72d..0c65549 100644 --- a/rl_sandbox/agents/dreamer/rssm.py +++ b/rl_sandbox/agents/dreamer/rssm.py @@ -6,7 +6,7 @@ from torch import nn from torch.nn import functional as F -from rl_sandbox.agents.dreamer import Dist, View +from rl_sandbox.agents.dreamer import Dist, View, GRUCell from rl_sandbox.utils.schedulers import LinearScheduler @dataclass @@ -39,40 +39,6 @@ def stack(cls, states: list['State'], dim = 0): torch.cat([state.stoch_logits for state in states], dim=dim), stochs) -# TODO: move to common -class GRUCell(nn.Module): - - def __init__(self, input_size, hidden_size, norm=False, update_bias=-1, **kwargs): - super().__init__() - self._size = hidden_size - self._act = torch.tanh - self._norm = norm - self._update_bias = update_bias - self._layer = nn.Linear(input_size + hidden_size, - 3 * hidden_size, - bias=norm is not None, - **kwargs) - if norm: - self._norm = nn.LayerNorm(3 * hidden_size) - - @property - def state_size(self): - return self._size - - def forward(self, x, h): - state = h - parts = self._layer(torch.concat([x, state], -1)) - if self._norm: - dtype = parts.dtype - parts = self._norm(parts.float()) - parts = parts.to(dtype=dtype) - reset, cand, update = parts.chunk(3, dim=-1) - reset = torch.sigmoid(reset) - cand = self._act(reset * cand) - update = torch.sigmoid(update + self._update_bias) - output = update * cand + (1 - update) * state - return output, output - class Quantize(nn.Module): diff --git a/rl_sandbox/agents/dreamer/rssm_slots.py b/rl_sandbox/agents/dreamer/rssm_slots.py index 21d68b4..ec52393 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots.py +++ b/rl_sandbox/agents/dreamer/rssm_slots.py @@ -4,10 +4,8 @@ import torch from jaxtyping import Bool, Float from torch import nn -from torch.nn import functional as F -from rl_sandbox.agents.dreamer import Dist, View -from rl_sandbox.utils.schedulers import LinearScheduler +from rl_sandbox.agents.dreamer import Dist, View, GRUCell @dataclass @@ -45,94 +43,6 @@ def stack(cls, states: list['State'], dim=0): torch.cat([state.stoch_logits for state in states], dim=dim), stochs) -class GRUCell(nn.Module): - - def __init__(self, input_size, hidden_size, norm=False, update_bias=-1, **kwargs): - super().__init__() - self._size = hidden_size - 
self._act = torch.tanh - self._norm = norm - self._update_bias = update_bias - self._layer = nn.Linear(input_size + hidden_size, - 3 * hidden_size, - bias=norm is not None, - **kwargs) - if norm: - self._norm = nn.LayerNorm(3 * hidden_size) - - @property - def state_size(self): - return self._size - - def forward(self, x, h): - state = h - parts = self._layer(torch.concat([x, state], -1)) - if self._norm: - dtype = parts.dtype - parts = self._norm(parts.float()) - parts = parts.to(dtype=dtype) - reset, cand, update = parts.chunk(3, dim=-1) - reset = torch.sigmoid(reset) - cand = self._act(reset * cand) - update = torch.sigmoid(update + self._update_bias) - output = update * cand + (1 - update) * state - return output, output - - -class Quantize(nn.Module): - - def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): - super().__init__() - - self.dim = dim - self.n_embed = n_embed - self.decay = decay - self.eps = eps - - embed = torch.randn(dim, n_embed) - self.inp_in = nn.Linear(1024, self.n_embed * self.dim) - self.inp_out = nn.Linear(self.n_embed * self.dim, 1024) - self.register_buffer("embed", embed) - self.register_buffer("cluster_size", torch.zeros(n_embed)) - self.register_buffer("embed_avg", embed.clone()) - - def forward(self, inp): - # input = self.inp_in(inp).reshape(-1, 1, self.n_embed, self.dim) - input = inp.reshape(-1, 1, self.n_embed, self.dim) - inp = input - flatten = input.reshape(-1, self.dim) - dist = (flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ self.embed + - self.embed.pow(2).sum(0, keepdim=True)) - _, embed_ind = (-dist).max(1) - embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) - embed_ind = embed_ind.view(*input.shape[:-1]) - quantize = self.embed_code(embed_ind) - - if self.training: - embed_onehot_sum = embed_onehot.sum(0) - embed_sum = flatten.transpose(0, 1) @ embed_onehot - - self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, - alpha=1 - self.decay) - self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) - n = self.cluster_size.sum() - cluster_size = ((self.cluster_size + self.eps) / - (n + self.n_embed * self.eps) * n) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) - self.embed.data.copy_(embed_normalized) - - # quantize_out = self.inp_out(quantize.reshape(-1, self.n_embed*self.dim)) - quantize_out = quantize - diff = 0.25 * (quantize_out.detach() - inp).pow(2).mean() + ( - quantize_out - inp.detach()).pow(2).mean() - quantize = inp + (quantize_out - inp).detach() - - return quantize, diff, embed_ind - - def embed_code(self, embed_id): - return F.embedding(embed_id, self.embed.transpose(0, 1)) - - class RSSM(nn.Module): """ Recurrent State Space Model @@ -195,9 +105,6 @@ def __init__(self, ]) # For observation we do not have ensemble - # FIXME: very bad magic number - # img_sz = 4 * 384 # 384x2x2 - # img_sz = 192 img_sz = embed_size self.stoch_net = nn.Sequential( # nn.LayerNorm(hidden_size + img_sz, hidden_size), @@ -207,10 +114,6 @@ def __init__(self, nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'obs_dist' View((1, -1, latent_dim, self.latent_classes))) - # self.determ_discretizer = MlpVAE(self.hidden_size) - self.determ_discretizer = Quantize(16, 16) - self.discretizer_scheduler = LinearScheduler(1.0, 0.0, 1_000_000) - self.determ_layer_norm = nn.LayerNorm(hidden_size) def estimate_stochastic_latent(self, prev_determ: torch.Tensor): dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] @@ -232,11 +135,6 @@ def predict_next(self, 
prev_state: State, action) -> State: prev_state.determ.flatten(1, 2)) if self.discrete_rssm: raise NotImplementedError("discrete rssm was not adopted for slot attention") - # determ_post, diff, embed_ind = self.determ_discretizer(determ_prior) - # determ_post = determ_post.reshape(determ_prior.shape) - # determ_post = self.determ_layer_norm(determ_post) - # alpha = self.discretizer_scheduler.val - # determ_post = alpha * determ_prior + (1-alpha) * determ_post else: determ_post, diff = determ_prior, 0 diff --git a/rl_sandbox/agents/dreamer/rssm_slots_attention.py b/rl_sandbox/agents/dreamer/rssm_slots_attention.py new file mode 100644 index 0000000..902c6c6 --- /dev/null +++ b/rl_sandbox/agents/dreamer/rssm_slots_attention.py @@ -0,0 +1,191 @@ +import typing as t +from dataclasses import dataclass + +import torch +from jaxtyping import Bool, Float +from torch import nn +import torch.nn.functional as F + +from rl_sandbox.agents.dreamer import Dist, View, GRUCell + + +@dataclass +class State: + determ: Float[torch.Tensor, 'seq batch num_slots determ'] + stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] + stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None + + @property + def combined(self): + return torch.concat([self.determ, self.stoch], dim=-1).flatten(2, 3) + + @property + def combined_slots(self): + return torch.concat([self.determ, self.stoch], dim=-1) + + @property + def stoch(self): + if self.stoch_ is None: + self.stoch_ = Dist( + self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:3] + (-1, )) + return self.stoch_ + + @property + def stoch_dist(self): + return Dist(self.stoch_logits) + + @classmethod + def stack(cls, states: list['State'], dim=0): + if states[0].stoch_ is not None: + stochs = torch.cat([state.stoch for state in states], dim=dim) + else: + stochs = None + return State(torch.cat([state.determ for state in states], dim=dim), + torch.cat([state.stoch_logits for state in states], dim=dim), stochs) + + +class RSSM(nn.Module): + """ + Recurrent State Space Model + + h_t <- deterministic state which is updated inside GRU + s^_t <- stohastic discrete prior state (used for KL divergence: + better predict future and encode smarter) + s_t <- stohastic discrete posterior state (latent representation of current state) + + h_1 ---> h_2 ---> h_3 ---> + \\ x_1 \\ x_2 \\ x_3 + | \\ | ^ | \\ | ^ | \\ | ^ + v MLP CNN | v MLP CNN | v MLP CNN | + \\ | | \\ | | \\ | | + Ensemble \\ | | Ensemble \\ | | Ensemble \\ | | + \\| | \\| | \\| | + | | | | | | | | | + v v | v v | v v | + | | | + s^_1 s_1 ---| s^_2 s_2 ---| s^_3 s_3 ---| + + """ + + def __init__(self, + latent_dim, + hidden_size, + actions_num, + latent_classes, + discrete_rssm, + norm_layer: nn.LayerNorm | nn.Identity, + embed_size=2 * 2 * 384): + super().__init__() + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.ensemble_num = 1 + self.hidden_size = hidden_size + self.discrete_rssm = discrete_rssm + + # Calculate deterministic state from prev stochastic, prev action and prev deterministic + self.pre_determ_recurrent = nn.Sequential( + nn.Linear(latent_dim * latent_classes + actions_num, + hidden_size), # Dreamer 'img_in' + norm_layer(hidden_size), + nn.ELU(inplace=True)) + self.determ_recurrent = GRUCell(input_size=hidden_size, + hidden_size=hidden_size, + norm=True) # Dreamer gru '_cell' + + # Calculate stochastic state from prior embed + # shared between all ensemble models + self.ensemble_prior_estimator = nn.ModuleList([ + 
nn.Sequential( + nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' + View((1, -1, latent_dim, self.latent_classes))) + for _ in range(self.ensemble_num) + ]) + + # For observation we do not have ensemble + img_sz = embed_size + self.stoch_net = nn.Sequential( + # nn.LayerNorm(hidden_size + img_sz, hidden_size), + nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'obs_dist' + View((1, -1, latent_dim, self.latent_classes))) + + self.hidden_attention_proj = nn.Linear(hidden_size, 3*hidden_size) + self.pre_norm = nn.LayerNorm(hidden_size) + + self.fc = nn.Linear(hidden_size, hidden_size) + self.fc_norm = nn.LayerNorm(hidden_size) + + self.attention_block_num = 3 + self.att_scale = hidden_size**(-0.5) + self.eps = 1e-8 + + def estimate_stochastic_latent(self, prev_determ: torch.Tensor): + dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] + # NOTE: Maybe something smarter can be used instead of + # taking only one random between all ensembles + # NOTE: in Dreamer ensemble_num is always 1 + idx = torch.randint(0, self.ensemble_num, ()) + return dists_per_model[0] + + def predict_next(self, prev_state: State, action) -> State: + x = self.pre_determ_recurrent( + torch.concat([ + prev_state.stoch, + action.unsqueeze(2).repeat((1, 1, prev_state.determ.shape[2], 1)) + ], + dim=-1)) + + # NOTE: x and determ are actually the same value if sequence of 1 is inserted + x, determ_prior = self.determ_recurrent(x.flatten(1, 2), + prev_state.determ.flatten(1, 2)) + if self.discrete_rssm: + raise NotImplementedError("discrete rssm was not adopted for slot attention") + else: + determ_post, diff = determ_prior, 0 + + determ_post = determ_post.reshape(prev_state.determ.shape) + + # TODO: Introduce self-attention block here ! + # Experiment, when only stochastic part is affected and deterministic is not touched + # We keep flow of gradients through determ block, but updating it with stochastic part + for _ in range(self.attention_block_num): + q, k, v = self.hidden_attention_proj(self.pre_norm(determ_post)).chunk(3, dim=-1) # 1xBxSlots_numxHidden_size + qk = torch.einsum('lbih,lbjh->lbij', q, k) + + attn = torch.softmax(self.att_scale * qk + self.eps, dim=-1) + attn = attn / attn.sum(dim=-1, keepdim=True) + + updates = torch.einsum('lbij,lbjh->lbih', qk, v) + determ_post = determ_post + self.fc(self.fc_norm(updates)) + + # used for KL divergence + predicted_stoch_logits = self.estimate_stochastic_latent(determ_post.reshape(determ_prior.shape)).reshape(prev_state.stoch_logits.shape) + # Size is 1 x B x slots_num x ... 
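+        # Shape note for the attention block above: in the einsums l=seq, b=batch,
+        # i/j=slot index, h=hidden. The slot update aggregates the values with the
+        # raw similarity scores `qk`; the normalized `attn` computed above is not
+        # used in the update itself.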
+ return State(determ_post, + predicted_stoch_logits), diff + + def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' + return State( + prior.determ, + self.stoch_net(torch.concat([prior.determ, embed], dim=-1)).flatten( + 1, 2).reshape(prior.stoch_logits.shape)) + + def forward(self, h_prev: State, embed, action) -> tuple[State, State]: + """ + 'h' <- internal state of the world + 'z' <- latent embedding of current observation + 'a' <- action taken on prev step + Returns 'h_next' <- the next next of the world + """ + prior, diff = self.predict_next(h_prev, action) + posterior = self.update_current(prior, embed) + + return prior, posterior, diff + diff --git a/rl_sandbox/agents/dreamer/rssm_slots_combined.py b/rl_sandbox/agents/dreamer/rssm_slots_combined.py new file mode 100644 index 0000000..32963a0 --- /dev/null +++ b/rl_sandbox/agents/dreamer/rssm_slots_combined.py @@ -0,0 +1,167 @@ +import typing as t +from dataclasses import dataclass + +import torch +from jaxtyping import Bool, Float +from torch import nn + +from rl_sandbox.agents.dreamer import Dist, View, GRUCell + + +@dataclass +class State: + determ: Float[torch.Tensor, 'seq batch num_slots determ'] + stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] + stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None + + @property + def combined(self): + return torch.concat([self.determ, self.stoch], dim=-1).flatten(2, 3) + + @property + def combined_slots(self): + return torch.concat([self.determ, self.stoch], dim=-1) + + @property + def stoch(self): + if self.stoch_ is None: + self.stoch_ = Dist( + self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:3] + (-1, )) + return self.stoch_ + + @property + def stoch_dist(self): + return Dist(self.stoch_logits) + + @classmethod + def stack(cls, states: list['State'], dim=0): + if states[0].stoch_ is not None: + stochs = torch.cat([state.stoch for state in states], dim=dim) + else: + stochs = None + return State(torch.cat([state.determ for state in states], dim=dim), + torch.cat([state.stoch_logits for state in states], dim=dim), stochs) + + +class RSSM(nn.Module): + """ + Recurrent State Space Model + + h_t <- deterministic state which is updated inside GRU + s^_t <- stohastic discrete prior state (used for KL divergence: + better predict future and encode smarter) + s_t <- stohastic discrete posterior state (latent representation of current state) + + h_1 ---> h_2 ---> h_3 ---> + \\ x_1 \\ x_2 \\ x_3 + | \\ | ^ | \\ | ^ | \\ | ^ + v MLP CNN | v MLP CNN | v MLP CNN | + \\ | | \\ | | \\ | | + Ensemble \\ | | Ensemble \\ | | Ensemble \\ | | + \\| | \\| | \\| | + | | | | | | | | | + v v | v v | v v | + | | | + s^_1 s_1 ---| s^_2 s_2 ---| s^_3 s_3 ---| + + """ + + def __init__(self, + latent_dim, + hidden_size, + actions_num, + latent_classes, + discrete_rssm, + norm_layer: nn.LayerNorm | nn.Identity, + slots_num: int, + embed_size=2 * 2 * 384): + super().__init__() + self.slots_num = slots_num + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.ensemble_num = 1 + self.hidden_size = hidden_size + self.discrete_rssm = discrete_rssm + + # Calculate deterministic state from prev stochastic, prev action and prev deterministic + self.pre_determ_recurrent = nn.Sequential( + nn.Linear(latent_dim * latent_classes + actions_num, + hidden_size), # Dreamer 'img_in' + norm_layer(hidden_size), + nn.ELU(inplace=True)) + self.determ_recurrent = GRUCell(input_size=hidden_size*slots_num, + 
hidden_size=hidden_size*slots_num, + norm=True) # Dreamer gru '_cell' + + # Calculate stochastic state from prior embed + # shared between all ensemble models + self.ensemble_prior_estimator = nn.ModuleList([ + nn.Sequential( + nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' + View((1, -1, latent_dim, self.latent_classes))) + for _ in range(self.ensemble_num) + ]) + + img_sz = embed_size + self.stoch_net = nn.Sequential( + # nn.LayerNorm(hidden_size + img_sz, hidden_size), + nn.Linear(hidden_size + img_sz, hidden_size), # Dreamer 'obs_out' + norm_layer(hidden_size), + nn.ELU(inplace=True), + nn.Linear(hidden_size, + latent_dim * self.latent_classes), # Dreamer 'obs_dist' + View((1, -1, latent_dim, self.latent_classes))) + + def estimate_stochastic_latent(self, prev_determ: torch.Tensor): + dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] + # NOTE: Maybe something smarter can be used instead of + # taking only one random between all ensembles + # NOTE: in Dreamer ensemble_num is always 1 + idx = torch.randint(0, self.ensemble_num, ()) + return dists_per_model[0] + + def predict_next(self, prev_state: State, action) -> State: + x = self.pre_determ_recurrent( + torch.concat([ + prev_state.stoch, + action.unsqueeze(2).repeat((1, 1, prev_state.determ.shape[2], 1)) + ], + dim=-1)) + # NOTE: x and determ are actually the same value if sequence of 1 is inserted + x, determ_prior = self.determ_recurrent(x.flatten(2, 3), + prev_state.determ.flatten(2, 3)) + if self.discrete_rssm: + raise NotImplementedError("discrete rssm was not adopted for slot attention") + else: + determ_post, diff = determ_prior, 0 + + # used for KL divergence + # TODO: Test both options (with slot in batch size and in feature dim) + predicted_stoch_logits = self.estimate_stochastic_latent(x.reshape(prev_state.determ.shape)) + # Size is 1 x B x slots_num x ... 
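+        # Design note: here a single GRU runs over the concatenation of all slot
+        # hidden states (hidden_size * slots_num), so cross-slot dynamics are learned
+        # jointly; the prior head above is still applied per slot after reshaping
+        # back to (seq, batch, slots_num, hidden_size).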
+ return State(determ_post.reshape(prev_state.determ.shape), + predicted_stoch_logits.reshape(prev_state.stoch_logits.shape)), diff + + def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' + return State( + prior.determ, + self.stoch_net(torch.concat([prior.determ, embed], dim=-1)).flatten( + 1, 2).reshape(prior.stoch_logits.shape)) + + def forward(self, h_prev: State, embed, action) -> tuple[State, State]: + """ + 'h' <- internal state of the world + 'z' <- latent embedding of current observation + 'a' <- action taken on prev step + Returns 'h_next' <- the next next of the world + """ + prior, diff = self.predict_next(h_prev, action) + posterior = self.update_current(prior, embed) + + return prior, posterior, diff + + diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index f8535ea..c01c196 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -53,12 +53,12 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, embed_size=self.n_dim) if encode_vit or decode_vit: # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) - # self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) - self.dino_vit = ViTFeat( - "/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", - feat_dim=384, - vit_arch='small', - patch_size=16) + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) + # self.dino_vit = ViTFeat( + # "/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + # feat_dim=384, + # vit_arch='small', + # patch_size=16) self.vit_feat_dim = self.dino_vit.feat_dim self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches self.dino_vit.requires_grad_(False) @@ -91,11 +91,12 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm, - channel_step=192, - kernel_sizes=[3, 4], - output_channels=self.vit_feat_dim+1, - return_dist=True) + norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + channel_step=192, + # kernel_sizes=[5, 5, 4], # for size 224x224 + kernel_sizes=[3, 4], + output_channels=self.vit_feat_dim+1, + return_dist=False) # self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, # 64*self.dino_vit.feat_dim, # hidden_size=2048, @@ -214,12 +215,11 @@ def KL(dist1, dist2): if self.decode_vit: inp = obs if not self.encode_vit: - ToTensor = tv.transforms.Compose([ - tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - tv.transforms.Resize(224, antialias=True) - ]) - # ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - # (0.229, 0.224, 0.225)) + # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + # (0.229, 0.224, 0.225)), + # tv.transforms.Resize(224, antialias=True)]) + ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) inp = ToTensor(obs + 0.5) d_features = self.dino_vit(inp) @@ -266,17 +266,26 @@ def KL(dist1, dist2): losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() else: - raise NotImplementedError("") - # if self.vit_l2_ratio != 1.0: - # x_r = 
self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) - # img_rec = -x_r.log_prob(obs).float().mean() - # else: - # img_rec = 0 - # x_r_detached = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1).detach()) - # losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() - # d_pred = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)) - # losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 14, 14)).float().mean()/4 + - # (1-self.vit_l2_ratio) * img_rec) + if self.vit_l2_ratio != 1.0: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean() + else: + img_rec = 0 + decoded_imgs_detached, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs_detached = decoded_imgs_detached * img_mask + x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) + losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() + + decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, 8, 8).split([self.vit_feat_dim, 1], dim=2) + feat_mask = F.softmax(masks, dim=1) + decoded_feats = decoded_feats * feat_mask + d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) + losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 8, 8)).float().mean() + + (1-self.vit_l2_ratio) * img_rec) prior_logits = prior.stoch_logits posterior_logits = posterior.stoch_logits @@ -294,3 +303,5 @@ def KL(dist1, dist2): losses['loss_kl_reg'] + losses['loss_discount_pred']) return losses, posterior, metrics + + diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py new file mode 100644 index 0000000..5be1529 --- /dev/null +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -0,0 +1,306 @@ +import typing as t + +import torch +import torch.distributions as td +import torchvision as tv +from torch import nn +from torch.nn import functional as F + +from rl_sandbox.agents.dreamer import Dist, Normalizer +from rl_sandbox.agents.dreamer.rssm_slots_attention import RSSM, State +from rl_sandbox.agents.dreamer.vision import Decoder, Encoder +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.utils.fc_nn import fc_nn_generator +from rl_sandbox.vision.dino import ViTFeat +from rl_sandbox.vision.slot_attention import PositionalEmbedding, SlotAttention + + +class WorldModel(nn.Module): + + def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, + actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, + discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, + decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True): + super().__init__() + self.use_prev_slots = use_prev_slots + self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.kl_beta = kl_loss_scale + + self.rssm_dim = rssm_dim + 
self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.slots_num = slots_num + self.state_size = slots_num * (rssm_dim + latent_dim * latent_classes) + + self.cluster_size = batch_cluster_size + self.actions_num = actions_num + # kl loss balancing (prior/posterior) + self.alpha = kl_loss_balancing + self.predict_discount = predict_discount + self.encode_vit = encode_vit + self.decode_vit = decode_vit + self.vit_l2_ratio = vit_l2_ratio + + self.n_dim = 384 + + self.recurrent_model = RSSM( + latent_dim, + rssm_dim, + actions_num, + latent_classes, + discrete_rssm, + norm_layer=nn.Identity if layer_norm else nn.LayerNorm, + embed_size=self.n_dim) + if encode_vit or decode_vit: + # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) + # self.dino_vit = ViTFeat( + # "/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + # feat_dim=384, + # vit_arch='small', + # patch_size=16) + self.vit_feat_dim = self.dino_vit.feat_dim + self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches + self.dino_vit.requires_grad_(False) + + if encode_vit: + self.encoder = nn.Sequential( + self.dino_vit, + nn.Flatten(), + # fc_nn_generator(64*self.dino_vit.feat_dim, + # 64*384, + # hidden_size=400, + # num_layers=5, + # intermediate_activation=nn.ELU, + # layer_norm=layer_norm) + ) + else: + self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + kernel_sizes=[4, 4, 4], + channel_step=96, + double_conv=True, + flatten_output=False) + + self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num) + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) + # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) + + self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), + nn.ReLU(inplace=True), + nn.Linear(self.n_dim, self.n_dim)) + + if decode_vit: + self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + channel_step=192, + # kernel_sizes=[5, 5, 4], # for size 224x224 + kernel_sizes=[3, 4], + output_channels=self.vit_feat_dim+1, + return_dist=False) + # self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, + # 64*self.dino_vit.feat_dim, + # hidden_size=2048, + # num_layers=5, + # intermediate_activation=nn.ELU, + # layer_norm=layer_norm, + # final_activation=DistLayer('mse')) + self.image_predictor = Decoder( + rssm_dim + latent_dim * latent_classes, + norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + output_channels=3+1, + return_dist=False) + + self.reward_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.discount_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('binary')) + self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): + device = next(self.parameters()).device + # Tuple of State-Space state and prev slots + return State( + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.rssm_dim, + device=device), + 
torch.zeros(seq_size, + batch_size, + self.slots_num, + self.latent_classes, + self.latent_dim, + device=device), + torch.zeros(seq_size, + batch_size, + self.slots_num, + self.latent_classes * self.latent_dim, + device=device)), None + + def predict_next(self, prev_state: State, action): + prior, _ = self.recurrent_model.predict_next(prev_state, action) + + reward = self.reward_predictor(prior.combined).mode + if self.predict_discount: + discount_factors = self.discount_predictor(prior.combined).sample() + else: + discount_factors = torch.ones_like(reward) + return prior, reward, discount_factors + + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, torch.Tensor]]) -> t.Tuple[State, torch.Tensor]: + if state is None or state[0] is None: + state, prev_slots = self.get_initial_state() + else: + if self.use_prev_slots: + state, prev_slots = state + else: + state, prev_slots = state[0], None + embed = self.encoder(obs.unsqueeze(0)) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + + pre_slot_features_t = self.slot_mlp( + embed_with_pos_enc.permute(0, 2, 3, 1).reshape(1, -1, self.n_dim)) + + slots_t = self.slot_attention(pre_slot_features_t, prev_slots) + + _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), + action) + return posterior, slots_t + + def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + discount: torch.Tensor, first: torch.Tensor): + b, _, h, w = obs.shape # s <- BxHxWx3 + + embed = self.encoder(obs) + embed_with_pos_enc = self.positional_augmenter_inp(embed) + # embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) + + pre_slot_features = self.slot_mlp( + embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) + pre_slot_features_c = pre_slot_features.reshape(b // self.cluster_size, + self.cluster_size, -1, self.n_dim) + + a_c = a.reshape(-1, self.cluster_size, self.actions_num) + r_c = r.reshape(-1, self.cluster_size, 1) + d_c = discount.reshape(-1, self.cluster_size, 1) + first_c = first.reshape(-1, self.cluster_size, 1) + + losses = {} + metrics = {} + + def KL(dist1, dist2): + KL_ = torch.distributions.kl_divergence + kl_lhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), + td.OneHotCategoricalStraightThrough(logits=dist1)).mean() + kl_rhs = KL_( + td.OneHotCategoricalStraightThrough(logits=dist2), + td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() + kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) + kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) + return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + + priors = [] + posteriors = [] + + if self.decode_vit: + inp = obs + if not self.encode_vit: + # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + # (0.229, 0.224, 0.225)), + # tv.transforms.Resize(224, antialias=True)]) + ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + inp = ToTensor(obs + 0.5) + d_features = self.dino_vit(inp) + + prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) + for t in range(self.cluster_size): + # s_t <- 1xB^xHxWx3 + pre_slot_feature_t, a_t, first_t = pre_slot_features_c[:, + t], a_c[:, t].unsqueeze( + 0 + ), first_c[:, + t].unsqueeze( + 0) + a_t = a_t * (1 - first_t) + + slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) + # FIXME: prev_slots was not used properly, need to rerun test + if self.use_prev_slots: + prev_slots = slots_t + else: + prev_slots = None + + prior, 
posterior, diff = self.recurrent_model.forward( + prev_state, slots_t.unsqueeze(0), a_t) + prev_state = posterior + + priors.append(prior) + posteriors.append(posterior) + + # losses['loss_determ_recons'] += diff + + posterior = State.stack(posteriors) + prior = State.stack(priors) + + r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) + f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) + + losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) + + if not self.decode_vit: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + + losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() + else: + if self.vit_l2_ratio != 1.0: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean() + else: + img_rec = 0 + decoded_imgs_detached, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = F.softmax(masks, dim=1) + decoded_imgs_detached = decoded_imgs_detached * img_mask + x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) + losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() + + decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, 8, 8).split([self.vit_feat_dim, 1], dim=2) + feat_mask = F.softmax(masks, dim=1) + decoded_feats = decoded_feats * feat_mask + d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) + losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 8, 8)).float().mean() + + (1-self.vit_l2_ratio) * img_rec) + + prior_logits = prior.stoch_logits + posterior_logits = posterior.stoch_logits + losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() + losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() + losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) + + metrics['reward_mean'] = r.mean() + metrics['reward_std'] = r.std() + metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() + metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() + metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() + + losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + + losses['loss_kl_reg'] + losses['loss_discount_pred']) + + return losses, posterior, metrics + diff --git a/rl_sandbox/agents/dreamer/world_model_slots_dino.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py similarity index 99% rename from rl_sandbox/agents/dreamer/world_model_slots_dino.py rename to rl_sandbox/agents/dreamer/world_model_slots_combined.py index 1476440..aa6f0c6 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_dino.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -7,7 +7,7 @@ from torch.nn import functional as F from rl_sandbox.agents.dreamer import Dist, Normalizer -from 
rl_sandbox.agents.dreamer.rssm_slots import RSSM, State +from rl_sandbox.agents.dreamer.rssm_slots_combined import RSSM, State from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, ViTDecoder from rl_sandbox.utils.dists import DistLayer from rl_sandbox.utils.fc_nn import fc_nn_generator @@ -50,6 +50,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, latent_classes, discrete_rssm, norm_layer=nn.Identity if layer_norm else nn.LayerNorm, + slots_num=slots_num, embed_size=self.n_dim) if encode_vit or decode_vit: # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) @@ -304,3 +305,4 @@ def KL(dist1, dist2): return losses, posterior, metrics + diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index a2212c1..e6ad242 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -140,7 +140,8 @@ def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observation # take some latent embeddings as initial with torch.cuda.amp.autocast(enabled=False): losses_wm, discovered_states, metrics_wm = self.world_model.calculate_loss(obs, a, r, discount_factors, first_flags) - self.world_model.recurrent_model.discretizer_scheduler.step() + # FIXME: wholely remove discrete RSSM + # self.world_model.recurrent_model.discretizer_scheduler.step() if self.world_model.decode_vit and self.world_model.vit_l2_ratio == 1.0: self.image_predictor_optimizer.step(losses_wm['loss_reconstruction_img']) diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 987190c..c52fdac 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -163,6 +163,20 @@ def on_step(self, logger): self._action_probs += self._action_probs self._latent_probs += self.agent._state[0].stoch_dist.probs.squeeze().mean(dim=0) + def on_episode(self, logger): + mu = self.agent.world_model.slot_attention.slots_mu + sigma = self.agent.world_model.slot_attention.slots_logsigma.exp() + mu_hist = torch.mean((mu - mu.squeeze(0).unsqueeze(1)) ** 2, dim=-1) + sigma_hist = torch.mean((sigma - sigma.squeeze(0).unsqueeze(1)) ** 2, dim=-1) + + logger.add_image('val/slot_attention_mu', mu_hist, self.episode) + logger.add_image('val/slot_attention_sigma', sigma_hist, self.episode, dataformats='HW') + + logger.add_scalar('val/slot_attention_mu_diff_max', mu_hist.max(), self.episode) + logger.add_scalar('val/slot_attention_sigma_diff_max', sigma_hist.max(), self.episode) + + super().on_episode(logger) + def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): obs = torch.from_numpy(obs.copy()).to(self.agent.device) obs = self.agent.preprocess_obs(obs) From 370cfba6f1ceb1515185faa1dcc1c4618e0b01b9 Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 20 Jun 2023 11:35:54 +0200 Subject: [PATCH 069/106] Added infra to precalc observation data inside replay buffer --- .../agents/dreamer/rssm_slots_combined.py | 34 ++++ rl_sandbox/agents/dreamer/world_model.py | 26 +-- .../agents/dreamer/world_model_slots.py | 25 +-- .../dreamer/world_model_slots_attention.py | 25 +-- .../dreamer/world_model_slots_combined.py | 25 +-- rl_sandbox/agents/dreamer_v2.py | 47 +++--- rl_sandbox/agents/random_agent.py | 3 +- rl_sandbox/metrics.py | 158 +++++++++--------- rl_sandbox/train.py | 35 ++-- rl_sandbox/utils/replay_buffer.py | 146 +++++++++------- rl_sandbox/utils/replay_buffer_old.py | 138 +++++++++++++++ rl_sandbox/utils/rollout_generation.py | 150 ++++++++--------- 12 
files changed, 519 insertions(+), 293 deletions(-) create mode 100644 rl_sandbox/utils/replay_buffer_old.py diff --git a/rl_sandbox/agents/dreamer/rssm_slots_combined.py b/rl_sandbox/agents/dreamer/rssm_slots_combined.py index 32963a0..3628915 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_combined.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_combined.py @@ -43,6 +43,40 @@ def stack(cls, states: list['State'], dim=0): torch.cat([state.stoch_logits for state in states], dim=dim), stochs) +class GRUCell(nn.Module): + + def __init__(self, input_size, hidden_size, norm=False, update_bias=-1, **kwargs): + super().__init__() + self._size = hidden_size + self._act = torch.tanh + self._norm = norm + self._update_bias = update_bias + self._layer = nn.Linear(input_size + hidden_size, + 3 * hidden_size, + bias=norm is not None, + **kwargs) + if norm: + self._norm = nn.LayerNorm(3 * hidden_size) + + @property + def state_size(self): + return self._size + + def forward(self, x, h): + state = h + parts = self._layer(torch.concat([x, state], -1)) + if self._norm: + dtype = parts.dtype + parts = self._norm(parts.float()) + parts = parts.to(dtype=dtype) + reset, cand, update = parts.chunk(3, dim=-1) + reset = torch.sigmoid(reset) + cand = self._act(reset * cand) + update = torch.sigmoid(update + self._update_bias) + output = update * cand + (1 - update) * state + return output, output + + class RSSM(nn.Module): """ Recurrent State Space Model diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index 4347b69..41818b9 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -102,6 +102,20 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: + if not self.decode_vit: + return {} + if not self.encode_vit: + # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + # (0.229, 0.224, 0.225)), + # tv.transforms.Resize(224, antialias=True)]) + ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + obs = ToTensor(obs + 0.5) + with torch.no_grad(): + d_features = self.dino_vit(obs.unsqueeze(0)).squeeze().cpu() + return {'d_features': d_features} + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): device = next(self.parameters()).device return State(torch.zeros(seq_size, batch_size, self.rssm_dim, device=device), @@ -127,7 +141,7 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State]) -> Sta return posterior def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, - discount: torch.Tensor, first: torch.Tensor): + discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) @@ -157,15 +171,7 @@ def KL(dist1, dist2, free_nat = True): posteriors = [] if self.decode_vit: - inp = obs - if not self.encode_vit: - # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), - # (0.229, 0.224, 0.225)), - # tv.transforms.Resize(224, antialias=True)]) - ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) - inp = ToTensor(obs + 0.5) - d_features = self.dino_vit(inp) + d_features = additional['d_features'] prev_state = self.get_initial_state(b // self.cluster_size) for t in 
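The precalc_data hooks introduced in this patch all do the same thing: run the frozen DINO ViT once when data enters the replay buffer, so the feature extraction is not repeated on every training step. A condensed sketch of the shared logic (the dino_vit callable and the exact tensor layout are assumptions taken from the surrounding hunks):

import torch
import torchvision as tv

IMAGENET_NORM = tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

@torch.no_grad()
def precalc_dino_features(obs: torch.Tensor, dino_vit: torch.nn.Module) -> dict[str, torch.Tensor]:
    # obs is assumed to lie in [-0.5, 0.5]; shift to [0, 1] before ImageNet normalisation
    inp = IMAGENET_NORM(obs + 0.5)
    return {'d_features': dino_vit(inp).cpu()}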
range(self.cluster_size): diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index c01c196..24c5999 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -126,6 +126,19 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: + if not self.decode_vit: + return {} + if not self.encode_vit: + # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + # (0.229, 0.224, 0.225)), + # tv.transforms.Resize(224, antialias=True)]) + ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + obs = ToTensor(obs + 0.5) + d_features = self.dino_vit(obs.unsqueeze(0)).squeeze() + return {'d_features': d_features} + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): device = next(self.parameters()).device # Tuple of State-Space state and prev slots @@ -178,7 +191,7 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t return posterior, slots_t def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, - discount: torch.Tensor, first: torch.Tensor): + discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) @@ -213,15 +226,7 @@ def KL(dist1, dist2): posteriors = [] if self.decode_vit: - inp = obs - if not self.encode_vit: - # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), - # (0.229, 0.224, 0.225)), - # tv.transforms.Resize(224, antialias=True)]) - ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) - inp = ToTensor(obs + 0.5) - d_features = self.dino_vit(inp) + d_features = additional['d_features'] prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) for t in range(self.cluster_size): diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index 5be1529..10f7101 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -126,6 +126,19 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: + if not self.decode_vit: + return {} + if not self.encode_vit: + # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + # (0.229, 0.224, 0.225)), + # tv.transforms.Resize(224, antialias=True)]) + ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + obs = ToTensor(obs + 0.5) + d_features = self.dino_vit(obs.unsqueeze(0)).squeeze() + return {'d_features': d_features} + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): device = next(self.parameters()).device # Tuple of State-Space state and prev slots @@ -178,7 +191,7 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t return posterior, slots_t def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, - discount: torch.Tensor, first: torch.Tensor): + discount: torch.Tensor, first: 
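The KL(dist1, dist2) helper that shows up as hunk context here is not included in the diff; in DreamerV2-style world models it typically implements balanced KL between the categorical posterior and prior logits, with the weight coming from kl_loss_balancing. A self-contained sketch of that idea (the 0.8 default and the manual categorical KL are illustrative assumptions, not the repository's implementation):

import torch

def balanced_kl(posterior_logits: torch.Tensor, prior_logits: torch.Tensor,
                alpha: float = 0.8) -> torch.Tensor:
    def kld(p_logits: torch.Tensor, q_logits: torch.Tensor) -> torch.Tensor:
        # KL(p || q) for categorical distributions given by logits over the last dim
        p = torch.softmax(p_logits, dim=-1)
        return (p * (torch.log_softmax(p_logits, dim=-1)
                     - torch.log_softmax(q_logits, dim=-1))).sum(dim=-1)
    # train the prior towards the (stopped) posterior, lightly regularise the posterior
    lhs = kld(posterior_logits.detach(), prior_logits)
    rhs = kld(posterior_logits, prior_logits.detach())
    return alpha * lhs.mean() + (1 - alpha) * rhs.mean()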
torch.Tensor, additional: dict[str, torch.Tensor]): b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) @@ -213,15 +226,7 @@ def KL(dist1, dist2): posteriors = [] if self.decode_vit: - inp = obs - if not self.encode_vit: - # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), - # (0.229, 0.224, 0.225)), - # tv.transforms.Resize(224, antialias=True)]) - ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) - inp = ToTensor(obs + 0.5) - d_features = self.dino_vit(inp) + d_features = additional['d_features'] prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) for t in range(self.cluster_size): diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index aa6f0c6..d788975 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -127,6 +127,19 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: + if not self.decode_vit: + return {} + if not self.encode_vit: + # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + # (0.229, 0.224, 0.225)), + # tv.transforms.Resize(224, antialias=True)]) + ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + obs = ToTensor(obs + 0.5) + d_features = self.dino_vit(obs).squeeze() + return {'d_features': d_features} + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): device = next(self.parameters()).device # Tuple of State-Space state and prev slots @@ -179,7 +192,7 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t return posterior, slots_t def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, - discount: torch.Tensor, first: torch.Tensor): + discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) @@ -214,15 +227,7 @@ def KL(dist1, dist2): posteriors = [] if self.decode_vit: - inp = obs - if not self.encode_vit: - # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), - # (0.229, 0.224, 0.225)), - # tv.transforms.Resize(224, antialias=True)]) - ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) - inp = ToTensor(obs + 0.5) - d_features = self.dino_vit(inp) + d_features = additional['d_features'] prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) for t in range(self.cluster_size): diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index e6ad242..0b9125a 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -3,16 +3,13 @@ import numpy as np import torch -from torch import nn from torch.nn import functional as F import torchvision as tv +from unpackable import unpack from rl_sandbox.agents.rl_agent import RlAgent -from rl_sandbox.utils.fc_nn import fc_nn_generator -from rl_sandbox.utils.replay_buffer import (Action, Actions, Observation, - Observations, Rewards, - TerminationFlags, IsFirstFlags) -from rl_sandbox.utils.optimizer import Optimizer +from rl_sandbox.utils.replay_buffer import (Action, Observation, + RolloutChunks, EnvStep, Rollout) from 
rl_sandbox.agents.dreamer.world_model import WorldModel, State from rl_sandbox.agents.dreamer.ac import ImaginativeCritic, ImaginativeActor @@ -91,6 +88,16 @@ def reset(self): self._last_action = torch.zeros((1, 1, self.actions_num), device=self.device) self._action_probs = torch.zeros((self.actions_num), device=self.device) + def preprocess(self, rollout: Rollout): + obs = self.preprocess_obs(rollout.obs) + additional = self.world_model.precalc_data(obs.to(self.device)) + return Rollout(obs=obs, + actions=rollout.actions, + rewards=rollout.rewards, + is_finished=rollout.is_finished, + is_first=rollout.is_first, + additional_data=rollout.additional_data | additional) + def preprocess_obs(self, obs: torch.Tensor): # FIXME: move to dataloader in replay buffer order = list(range(len(obs.shape))) @@ -105,8 +112,7 @@ def preprocess_obs(self, obs: torch.Tensor): # return obs.type(torch.float32).permute(order) def get_action(self, obs: Observation) -> Action: - # NOTE: pytorch fails without .copy() only when get_action is called - obs = torch.from_numpy(obs.copy()).to(self.device) + obs = torch.from_numpy(obs).to(self.device) obs = self.preprocess_obs(obs) self._state = self.world_model.get_latent(obs, self._last_action, self._state) @@ -118,7 +124,7 @@ def get_action(self, obs: Observation) -> Action: self._action_probs += actor_dist.probs.squeeze() if self.is_discrete: - return self._last_action.squeeze().detach().cpu().numpy().argmax() + return self._last_action.argmax() else: return self._last_action.squeeze().detach().cpu().numpy() @@ -126,20 +132,18 @@ def from_np(self, arr: np.ndarray): arr = torch.from_numpy(arr) if isinstance(arr, np.ndarray) else arr return arr.to(self.device, non_blocking=True) - def train(self, obs: Observations, a: Actions, r: Rewards, next_obs: Observations, - is_finished: TerminationFlags, is_first: IsFirstFlags): - - obs = self.preprocess_obs(self.from_np(obs)) - a = self.from_np(a) + def train(self, rollout_chunks: RolloutChunks): + obs, a, r, is_finished, is_first, additional = unpack(rollout_chunks) + torch.cuda.current_stream().synchronize() + # obs = self.preprocess_obs(self.from_np(obs)) if self.is_discrete: a = F.one_hot(a.to(torch.int64), num_classes=self.actions_num).squeeze() - r = self.from_np(r) - discount_factors = (1 - self.from_np(is_finished).type(torch.float32)) - first_flags = self.from_np(is_first).type(torch.float32) + discount_factors = (1 - is_finished).float() + first_flags = is_first.float() # take some latent embeddings as initial with torch.cuda.amp.autocast(enabled=False): - losses_wm, discovered_states, metrics_wm = self.world_model.calculate_loss(obs, a, r, discount_factors, first_flags) + losses_wm, discovered_states, metrics_wm = self.world_model.calculate_loss(obs, a, r, discount_factors, first_flags, additional) # FIXME: wholely remove discrete RSSM # self.world_model.recurrent_model.discretizer_scheduler.step() @@ -196,17 +200,18 @@ def save_ckpt(self, epoch_num: int, losses: dict[str, float]): { 'epoch': epoch_num, 'world_model_state_dict': self.world_model.state_dict(), - 'world_model_optimizer_state_dict': self.world_model_optimizer.state_dict(), + 'world_model_optimizer_state_dict': self.world_model_optimizer.optimizer.state_dict(), 'actor_state_dict': self.actor.state_dict(), 'critic_state_dict': self.critic.state_dict(), - 'actor_optimizer_state_dict': self.actor_optimizer.state_dict(), - 'critic_optimizer_state_dict': self.critic_optimizer.state_dict(), + 'actor_optimizer_state_dict': 
self.actor_optimizer.optimizer.state_dict(), + 'critic_optimizer_state_dict': self.critic_optimizer.optimizer.state_dict(), 'losses': losses }, f'dreamerV2-{epoch_num}-{losses["total"]}.ckpt') def load_ckpt(self, ckpt_path: Path): ckpt = torch.load(ckpt_path) self.world_model.load_state_dict(ckpt['world_model_state_dict']) + # FIXME: doesn't work for optimizers self.world_model_optimizer.load_state_dict( ckpt['world_model_optimizer_state_dict']) self.actor.load_state_dict(ckpt['actor_state_dict']) diff --git a/rl_sandbox/agents/random_agent.py b/rl_sandbox/agents/random_agent.py index ea5da16..0fbb1bc 100644 --- a/rl_sandbox/agents/random_agent.py +++ b/rl_sandbox/agents/random_agent.py @@ -1,4 +1,5 @@ import numpy as np +import torch from nptyping import Float, NDArray, Shape from pathlib import Path @@ -13,7 +14,7 @@ def __init__(self, env: Env): self.action_space = env.action_space def get_action(self, obs: State) -> Action | NDArray[Shape["*"],Float]: - return self.action_space.sample() + return torch.from_numpy(np.array(self.action_space.sample())) def train(self, s: States, a: Actions, r: Rewards, next: States, is_finished: TerminationFlags): return dict() diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index c52fdac..0d2acb9 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -26,8 +26,8 @@ def on_val(self, logger, rollouts: list[Rollout]): metrics = self.calculate_metrics(rollouts) logger.log(metrics, self.episode, mode='val') if self.log_video: - video = np.expand_dims(rollouts[0].observations.transpose(0, 3, 1, 2), 0) - logger.add_video('val/visualization', video, self.episode) + video = rollouts[0].obs.unsqueeze(0) + logger.add_video('val/visualization', video.numpy() + 0.5, self.episode) self.episode += 1 def calculate_metrics(self, rollouts: list[Rollout]): @@ -37,7 +37,7 @@ def calculate_metrics(self, rollouts: list[Rollout]): } def _episode_duration(self, rollouts: list[Rollout]): - return np.mean(list(map(lambda x: len(x.states), rollouts))) + return np.mean(list(map(lambda x: len(x.obs), rollouts))) def _episode_return(self, rollouts: list[Rollout]): return np.mean(list(map(lambda x: sum(x.rewards), rollouts))) @@ -89,9 +89,7 @@ def on_val(self, logger, rollouts: list[Rollout]): self.viz_log(rollouts[0], logger, self.episode) def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): - obs = torch.from_numpy(obs.copy()).to(self.agent.device) - obs = self.agent.preprocess_obs(obs) - actions = self.agent.from_np(actions) + # obs = self.agent.preprocess_obs(obs) if self.agent.is_discrete: actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() video = [] @@ -106,46 +104,45 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ if idx > update_num: break state = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), state) - video_r = self.agent.world_model.image_predictor(state.combined).mode.cpu().detach().numpy() + video_r = self.agent.world_model.image_predictor(state.combined).mode rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) if self.agent.world_model.encode_vit: - video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + video_r = UnNormalize(video_r) else: video_r = (video_r + 0.5) - video.append(video_r) + video.append(video_r.clamp(0, 1)) rews = torch.Tensor(rews).to(obs.device) if update_num < len(obs): states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), 
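UnNormalize in these metric helpers is assumed to invert the ImageNet normalisation applied when encode_vit is enabled; it is not defined in this diff, but a typical definition consistent with that usage is:

import torchvision as tv

# Inverse of Normalize(mean, std): first divide out the std, then add the mean back.
UnNormalize = tv.transforms.Compose([
    tv.transforms.Normalize(mean=(0.0, 0.0, 0.0),
                            std=(1 / 0.229, 1 / 0.224, 1 / 0.225)),
    tv.transforms.Normalize(mean=(-0.485, -0.456, -0.406),
                            std=(1.0, 1.0, 1.0)),
])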
horizon=self.agent.imagination_horizon - 1 - update_num) rews = torch.cat([rews, rews_2[1:].squeeze()]) - video_r = self.agent.world_model.image_predictor(states.combined[1:]).mode.cpu().detach().numpy() + video_r = self.agent.world_model.image_predictor(states.combined[1:]).mode.detach() if self.agent.world_model.encode_vit: - video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + video_r = UnNormalize(video_r) else: video_r = (video_r + 0.5) video.append(video_r) - return np.concatenate(video), rews + return torch.cat(video), rews def viz_log(self, rollout, logger, epoch_num): - init_indeces = np.random.choice(len(rollout.states) - self.agent.imagination_horizon, 5) + rollout = rollout.to(device=self.agent.device) + init_indeces = np.random.choice(len(rollout.obs) - self.agent.imagination_horizon, 5) - videos = np.concatenate([ - rollout.next_states[init_idx:init_idx + self.agent.imagination_horizon].transpose( - 0, 3, 1, 2) for init_idx in init_indeces - ], axis=3).astype(np.float32) / 255.0 + videos = torch.cat([ + rollout.obs[init_idx:init_idx + self.agent.imagination_horizon] for init_idx in init_indeces + ], dim=3) + 0.5 real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] - videos_r, imagined_rewards = zip(*[self._generate_video(obs_0.copy(), a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( - [rollout.next_states[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], + videos_r, imagined_rewards = zip(*[self._generate_video(obs_0, a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( + [rollout.obs[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], [rollout.actions[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces]) ]) - videos_r = np.concatenate(videos_r, axis=3) + videos_r = torch.cat(videos_r, dim=3) - videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r + 1)/2], axis=2), 0) - videos_comparison = (videos_comparison * 255.0).astype(np.uint8) + videos_comparison = torch.cat([videos, videos_r, torch.abs(videos - videos_r + 1)/2], dim=2).unsqueeze(0) logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) @@ -169,8 +166,8 @@ def on_episode(self, logger): mu_hist = torch.mean((mu - mu.squeeze(0).unsqueeze(1)) ** 2, dim=-1) sigma_hist = torch.mean((sigma - sigma.squeeze(0).unsqueeze(1)) ** 2, dim=-1) - logger.add_image('val/slot_attention_mu', mu_hist, self.episode) - logger.add_image('val/slot_attention_sigma', sigma_hist, self.episode, dataformats='HW') + logger.add_image('val/slot_attention_mu', mu_hist/mu_hist.max(), self.episode, dataformats='HW') + logger.add_image('val/slot_attention_sigma', sigma_hist/sigma_hist.max(), self.episode, dataformats='HW') logger.add_scalar('val/slot_attention_mu_diff_max', mu_hist.max(), self.episode) logger.add_scalar('val/slot_attention_sigma_diff_max', sigma_hist.max(), self.episode) @@ -178,9 +175,9 @@ def on_episode(self, logger): super().on_episode(logger) def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): - obs = torch.from_numpy(obs.copy()).to(self.agent.device) - obs = self.agent.preprocess_obs(obs) - actions = self.agent.from_np(actions) + # obs = torch.from_numpy(obs.copy()).to(self.agent.device) + # obs = self.agent.preprocess_obs(obs) + # actions = self.agent.from_np(actions) if self.agent.is_discrete: actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() video = [] @@ 
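The comparison video assembled above stacks ground truth, reconstruction, and a remapped difference image along the image height, then adds the leading batch dimension that add_video expects ((N, T, C, H, W), float values in [0, 1]). Schematically, with the helper name being illustrative:

import torch

def make_comparison(videos: torch.Tensor, videos_r: torch.Tensor) -> torch.Tensor:
    # videos, videos_r: (T, C, H, W) in [0, 1] -> (1, T, C, 3H, W)
    diff = torch.abs(videos - videos_r + 1) / 2   # signed error remapped to [0, 1]; equal pixels show as 0.5 grey
    return torch.cat([videos, videos_r, diff], dim=2).unsqueeze(0)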
-197,21 +194,21 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ if idx > update_num: break state, prev_slots = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), (state, prev_slots)) - # video_r = self.agent.world_model.image_predictor(state.combined_slots).mode.cpu().detach().numpy() + # video_r = self.agent.world_model.image_predictor(state.combined_slots).mode decoded_imgs, masks = self.agent.world_model.image_predictor(state.combined_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) # TODO: try the scaling of softmax as in attention img_mask = F.softmax(masks, dim=1) decoded_imgs = decoded_imgs * img_mask - video_r = torch.sum(decoded_imgs, dim=1).cpu().detach().numpy() + video_r = torch.sum(decoded_imgs, dim=1) rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) if self.agent.world_model.encode_vit: - video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + video_r = UnNormalize(video_r) else: video_r = (video_r + 0.5) - video.append(video_r) - slots_video.append(decoded_imgs.cpu().detach().numpy() + 0.5) + video.append(video_r.clamp(0, 1)) + slots_video.append((decoded_imgs + 0.5).clamp(0, 1)) rews = torch.Tensor(rews).to(obs.device) @@ -219,43 +216,42 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) rews = torch.cat([rews, rews_2[1:].squeeze()]) - # video_r = self.agent.world_model.image_predictor(states.combined_slots[1:]).mode.cpu().detach().numpy() + # video_r = self.agent.world_model.image_predictor(states.combined_slots[1:]).mode decoded_imgs, masks = self.agent.world_model.image_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) img_mask = F.softmax(masks, dim=1) decoded_imgs = decoded_imgs * img_mask - video_r = torch.sum(decoded_imgs, dim=1).cpu().detach().numpy() + video_r = torch.sum(decoded_imgs, dim=1) if self.agent.world_model.encode_vit: - video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + video_r = UnNormalize(video_r) else: video_r = (video_r + 0.5) video.append(video_r) - slots_video.append(decoded_imgs.cpu().detach().numpy() + 0.5) + slots_video.append(decoded_imgs + 0.5) - return np.concatenate(video), rews, np.concatenate(slots_video) + return torch.cat(video), rews, torch.cat(slots_video) def viz_log(self, rollout, logger, epoch_num): - init_indeces = np.random.choice(len(rollout.states) - self.agent.imagination_horizon, 5) + rollout = rollout.to(device=self.agent.device) + init_indeces = np.random.choice(len(rollout.obs) - self.agent.imagination_horizon, 5) - videos = np.concatenate([ - rollout.next_states[init_idx:init_idx + self.agent.imagination_horizon].transpose( - 0, 3, 1, 2) for init_idx in init_indeces - ], axis=3).astype(np.float32) / 255.0 + videos = torch.cat([ + rollout.obs[init_idx:init_idx + self.agent.imagination_horizon] for init_idx in init_indeces + ], dim=3) + 0.5 real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] - videos_r, imagined_rewards, slots_video = zip(*[self._generate_video(obs_0.copy(), a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( - [rollout.next_states[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], + videos_r, imagined_rewards, slots_video = 
zip(*[self._generate_video(obs_0, a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( + [rollout.obs[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], [rollout.actions[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces]) ]) - videos_r = np.concatenate(videos_r, axis=3) + videos_r = torch.cat(videos_r, dim=3) - slots_video = np.concatenate(list(slots_video)[:3], axis=3) - slots_video = slots_video.transpose((0, 2, 3, 1, 4)) - slots_video = np.expand_dims(slots_video.reshape(*slots_video.shape[:-2], -1), 0) + slots_video = torch.cat(list(slots_video)[:3], dim=3) + slots_video = slots_video.permute((0, 2, 3, 1, 4)) + slots_video = slots_video.reshape(*slots_video.shape[:-2], -1).unsqueeze(0) - videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r + 1)/2], axis=2), 0) - videos_comparison = (videos_comparison * 255.0).astype(np.uint8) + videos_comparison = torch.cat([videos, videos_r, torch.abs(videos - videos_r + 1)/2], dim=2).unsqueeze(0) logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) logger.add_video('val/dreamed_slots', slots_video, epoch_num) @@ -267,9 +263,9 @@ def viz_log(self, rollout, logger, epoch_num): class SlottedDinoDreamerMetricsEvaluator(SlottedDreamerMetricsEvaluator): def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): - obs = torch.from_numpy(obs.copy()).to(self.agent.device) - obs = self.agent.preprocess_obs(obs) - actions = self.agent.from_np(actions) + # obs = torch.from_numpy(obs.copy()).to(self.agent.device) + # obs = self.agent.preprocess_obs(obs) + # actions = self.agent.from_np(actions) if self.agent.is_discrete: actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() video = [] @@ -287,13 +283,13 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ if idx > update_num: break state, prev_slots = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), (state, prev_slots)) - # video_r = self.agent.world_model.image_predictor(state.combined_slots).mode.cpu().detach().numpy() + # video_r = self.agent.world_model.image_predictor(state.combined_slots).mode decoded_imgs, masks = self.agent.world_model.image_predictor(state.combined_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) # TODO: try the scaling of softmax as in attention img_mask = F.softmax(masks, dim=1) decoded_imgs = decoded_imgs * img_mask - video_r = torch.sum(decoded_imgs, dim=1).cpu().detach().numpy() + video_r = torch.sum(decoded_imgs, dim=1) _, vit_masks = self.agent.world_model.dino_predictor(state.combined_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, 8, 8).split([self.agent.world_model.vit_feat_dim, 1], dim=2) vit_mask = F.softmax(vit_masks, dim=1) @@ -303,12 +299,12 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) if self.agent.world_model.encode_vit: - video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + video_r = UnNormalize(video_r) else: video_r = (video_r + 0.5) - video.append(video_r) - slots_video.append(decoded_imgs.cpu().detach().numpy() + 0.5) - vit_slots_video.append(per_slot_vit.cpu().detach().numpy()/upscaled_mask.max().cpu().detach().numpy() + 0.5) + video.append(video_r.clamp(0, 1)) + slots_video.append((decoded_imgs + 0.5).clamp(0, 1)) + 
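The permute/reshape pair used on slots_video lays the per-slot reconstructions out side by side along the width so they can be logged as one video. A sketch of that rearrangement (the (T, S, C, H, W) layout is taken from the surrounding code):

import torch

def tile_slots(slots_video: torch.Tensor) -> torch.Tensor:
    # (T, S, C, H, W) per-slot frames -> (1, T, C, H, S*W) with slots side by side
    t, s, c, h, w = slots_video.shape
    tiled = slots_video.permute(0, 2, 3, 1, 4)         # (T, C, H, S, W)
    return tiled.reshape(t, c, h, s * w).unsqueeze(0)  # flatten slots into the width axis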
vit_slots_video.append((per_slot_vit/upscaled_mask.max() + 0.5).clamp(0, 1)) rews = torch.Tensor(rews).to(obs.device) @@ -316,11 +312,11 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) rews = torch.cat([rews, rews_2[1:].squeeze()]) - # video_r = self.agent.world_model.image_predictor(states.combined_slots[1:]).mode.cpu().detach().numpy() + # video_r = self.agent.world_model.image_predictor(states.combined_slots[1:]).mode decoded_imgs, masks = self.agent.world_model.image_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) img_mask = F.softmax(masks, dim=1) decoded_imgs = decoded_imgs * img_mask - video_r = torch.sum(decoded_imgs, dim=1).cpu().detach().numpy() + video_r = torch.sum(decoded_imgs, dim=1) _, vit_masks = self.agent.world_model.dino_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, 8, 8).split([self.agent.world_model.vit_feat_dim, 1], dim=2) vit_mask = F.softmax(vit_masks, dim=1) @@ -330,41 +326,41 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ # per_slot_vit = (upscaled_mask.unsqueeze(1) * o.to(self.agent.device).unsqueeze(0)).unsqueeze(0) if self.agent.world_model.encode_vit: - video_r = UnNormalize(torch.from_numpy(video_r)).numpy() + video_r = UnNormalize(video_r) else: video_r = (video_r + 0.5) - video.append(video_r) - slots_video.append(decoded_imgs.cpu().detach().numpy() + 0.5) - vit_slots_video.append(per_slot_vit.cpu().detach().numpy()/np.expand_dims(upscaled_mask.cpu().detach().numpy().max(axis=(1,2,3)), axis=(1,2,3,4)) + 0.5) + video.append(video_r.clamp(0, 1)) + slots_video.append((decoded_imgs + 0.5).clamp(0, 1)) + vit_slots_video = None # FIXME: this is not correct + # vit_slots_video.append(per_slot_vit/np.expand_dims(upscaled_mask.max(axis=(1,2,3)), axis=(1,2,3,4)) + 0.5) - return np.concatenate(video), rews, np.concatenate(slots_video), np.concatenate(vit_slots_video) + return torch.cat(video), rews, torch.cat(slots_video), torch.cat(vit_slots_video) def viz_log(self, rollout, logger, epoch_num): - init_indeces = np.random.choice(len(rollout.states) - self.agent.imagination_horizon, 5) + rollout = rollout.to(device=self.agent.device) + init_indeces = np.random.choice(len(rollout.obs) - self.agent.imagination_horizon, 5) - videos = np.concatenate([ - rollout.next_states[init_idx:init_idx + self.agent.imagination_horizon].transpose( - 0, 3, 1, 2) for init_idx in init_indeces - ], axis=3).astype(np.float32) / 255.0 + videos = torch.cat([ + rollout.obs[init_idx:init_idx + self.agent.imagination_horizon] for init_idx in init_indeces + ], dim=3) + 0.5 real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] videos_r, imagined_rewards, slots_video, vit_masks_video = zip(*[self._generate_video(obs_0.copy(), a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( - [rollout.next_states[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], + [rollout.obs[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], [rollout.actions[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces]) ]) - videos_r = np.concatenate(videos_r, axis=3) + videos_r = torch.cat(videos_r, dim=3) - slots_video = 
np.concatenate(list(slots_video)[:3], axis=3) - slots_video = slots_video.transpose((0, 2, 3, 1, 4)) - slots_video = np.expand_dims(slots_video.reshape(*slots_video.shape[:-2], -1), 0) + slots_video = torch.cat(list(slots_video)[:3], dim=3) + slots_video = slots_video.permute((0, 2, 3, 1, 4)) + slots_video = slots_video.reshape(*slots_video.shape[:-2], -1).unsqueeze(0) - videos_comparison = np.expand_dims(np.concatenate([videos, videos_r, np.abs(videos - videos_r + 1)/2], axis=2), 0) - videos_comparison = (videos_comparison * 255.0).astype(np.uint8) + videos_comparison = torch.cat([videos, videos_r, torch.abs(videos - videos_r + 1)/2], dim=2).unsqueeze(0) - vit_masks_video = np.concatenate(list(vit_masks_video)[:3], axis=3) - vit_masks_video = vit_masks_video.transpose((0, 2, 3, 1, 4)) - vit_masks_video = np.expand_dims(slots_video.reshape(*vit_masks_video.shape[:-2], -1), 0) + vit_masks_video = torch.cat(list(vit_masks_video)[:3], dim=3) + vit_masks_video = vit_masks_video.permute((0, 2, 3, 1, 4)) + vit_masks_video = slots_video.reshape(*vit_masks_video.shape[:-2], -1).unsqueeze(0) logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) logger.add_video('val/dreamed_slots', slots_video, epoch_num) diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 5d3bcad..3c45597 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -14,7 +14,6 @@ from torch.profiler import ProfilerActivity, profile from tqdm import tqdm -from rl_sandbox.metrics import EpisodeMetricsEvaluator from rl_sandbox.utils.env import Env from rl_sandbox.utils.logger import Logger from rl_sandbox.utils.replay_buffer import ReplayBuffer @@ -26,6 +25,7 @@ def val_logs(agent, val_cfg: DictConfig, metrics, env: Env, logger: Logger): with torch.no_grad(): rollouts = collect_rollout_num(env, val_cfg.rollout_num, agent, collect_obs=True) + rollouts = [agent.preprocess(r) for r in rollouts] for metric in metrics: metric.on_val(logger, rollouts) @@ -65,14 +65,6 @@ def main(cfg: DictConfig): save_video=False, save_episode=False) - buff = ReplayBuffer(prioritize_ends=cfg.training.prioritize_ends, - min_ep_len=cfg.agent.get('batch_cluster_size', 1) * - (cfg.training.prioritize_ends + 1)) - fillup_replay_buffer( - env, buff, - max(cfg.training.prefill, - cfg.training.batch_size * cfg.agent.get('batch_cluster_size', 1))) - is_discrete = isinstance(env.action_space, Discrete) agent = hydra.utils.instantiate( cfg.agent, @@ -82,6 +74,18 @@ def main(cfg: DictConfig): device_type=cfg.device_type, logger=logger) + buff = ReplayBuffer(prioritize_ends=cfg.training.prioritize_ends, + min_ep_len=cfg.agent.get('batch_cluster_size', 1) * + (cfg.training.prioritize_ends + 1), + preprocess_func=agent.preprocess, + device = cfg.device_type) + + fillup_replay_buffer( + env, buff, + max(cfg.training.prefill, + cfg.training.batch_size * cfg.agent.get('batch_cluster_size', 1)), + agent=agent) + metrics = [metric(agent) for metric in hydra.utils.instantiate(cfg.validation.metrics)] prof = profile( @@ -93,10 +97,10 @@ def main(cfg: DictConfig): for i in tqdm(range(int(cfg.training.pretrain)), desc='Pretraining'): if cfg.training.checkpoint_path is not None: break - s, a, r, n, f, first = buff.sample(cfg.training.batch_size, + rollout_chunks = buff.sample(cfg.training.batch_size, cluster_size=cfg.agent.get( 'batch_cluster_size', 1)) - losses = agent.train(s, a, r, n, f, first) + losses = agent.train(rollout_chunks) logger.log(losses, i, mode='pre_train') val_logs(agent, cfg.validation, metrics, val_env, logger) @@ -110,17 +114,18 @@ 
def main(cfg: DictConfig): while global_step < cfg.training.steps: ### Training and exploration - for s, a, r, n, f, _ in iter_rollout(env, agent): - buff.add_sample(s, a, r, n, f) + for env_step in iter_rollout(env, agent): + # env_step = agent.preprocess(env_step) + buff.add_sample(env_step) if global_step % cfg.training.train_every == 0: # NOTE: unintuitive that batch_size is now number of total # samples, but not amount of sequences for recurrent model - s, a, r, n, f, first = buff.sample(cfg.training.batch_size, + rollout_chunk = buff.sample(cfg.training.batch_size, cluster_size=cfg.agent.get( 'batch_cluster_size', 1)) - losses = agent.train(s, a, r, n, f, first) + losses = agent.train(rollout_chunk) if cfg.debug.profiler: prof.step() if global_step % 100 == 0: diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index 4572376..0b8d541 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -1,40 +1,66 @@ import typing as t from collections import deque -from dataclasses import dataclass +from dataclasses import dataclass, field +from unpackable import unpack +import torch import numpy as np -from nptyping import Bool, Float, Int, NDArray, Shape +from jaxtyping import Bool, Float, Int -Observation = NDArray[Shape["*,*,3"], Int] -State = NDArray[Shape["*"], Float] | Observation -Action = NDArray[Shape["*"], Int] +Observation = Int[torch.Tensor, 'n n 3'] +State = Float[torch.Tensor, 'n'] +Action = Int[torch.Tensor, 'n'] -Observations = NDArray[Shape["*,*,*,3"], Int] -States = NDArray[Shape["*,*"], Float] | Observations -Actions = NDArray[Shape["*,*"], Int] -Rewards = NDArray[Shape["*"], Float] -TerminationFlags = NDArray[Shape["*"], Bool] +Observations = Int[torch.Tensor, 'batch n n 3'] +States = Float[torch.Tensor, 'batch n'] +Actions = Int[torch.Tensor, 'batch n'] +Rewards = Float[torch.Tensor, 'batch'] +TerminationFlags = Bool[torch.Tensor, 'batch'] IsFirstFlags = TerminationFlags +@dataclass +class EnvStep: + obs: Observation + action: Action + reward: float + is_finished: bool + is_first: bool + additional_data: dict[str, Float[torch.Tensor, '...']] = field(default_factory=dict) @dataclass class Rollout: - states: States + obs: Observations actions: Actions rewards: Rewards - next_states: States is_finished: TerminationFlags - observations: t.Optional[Observations] = None + is_first: IsFirstFlags + additional_data: dict[str, Float[torch.Tensor, 'batch ...']] = field(default_factory=dict) def __len__(self): - return len(self.states) + return len(self.obs) + + def to(self, device: str, non_blocking: bool = True): + self.obs = self.obs.to(device, non_blocking=True) + self.actions = self.actions.to(device, non_blocking=True) + self.rewards = self.rewards.to(device, non_blocking=True) + self.is_finished = self.is_finished.to(device, non_blocking=True) + self.is_first = self.is_first.to(device, non_blocking=True) + for k, v in self.additional_data.items(): + self.additional_data[k] = v.to(device, non_blocking = True) + if not non_blocking: + torch.cuda.current_stream().synchronize() + return self + +@dataclass +class RolloutChunks(Rollout): + pass -# TODO: make buffer concurrent-friendly class ReplayBuffer: def __init__(self, max_len=2e6, prioritize_ends: bool = False, min_ep_len: int = 1, + preprocess_func: t.Callable[[Rollout], Rollout] = lambda x: x, device: str = 'cpu'): self.rollouts: deque[Rollout] = deque() self.rollouts_len: deque[int] = deque() @@ -44,17 +70,15 @@ def __init__(self, max_len=2e6, self.max_len = 
max_len self.total_num = 0 self.device = device + self.preprocess_func = preprocess_func def __len__(self): return self.total_num def add_rollout(self, rollout: Rollout): - if len(rollout.next_states) <= self.min_ep_len: + if len(rollout.obs) <= self.min_ep_len: return - # NOTE: only last next state is stored, all others are induced - # from state on next step - rollout.next_states = np.expand_dims(rollout.next_states[-1], 0) - self.rollouts.append(rollout) + self.rollouts.append(self.preprocess_func(rollout).to(device='cpu')) self.total_num += len(self.rollouts[-1].rewards) self.rollouts_len.append(len(self.rollouts[-1].rewards)) @@ -66,21 +90,29 @@ def add_rollout(self, rollout: Rollout): # Add sample expects that each subsequent sample # will be continuation of last rollout util termination flag true # is encountered - def add_sample(self, s: State, a: Action, r: float, n: State, f: bool): + def add_sample(self, env_step: EnvStep): + s, a, r, n, f, additional = unpack(env_step) if self.curr_rollout is None: - self.curr_rollout = Rollout([s], [a], [r], None, [f]) + self.curr_rollout = Rollout([s], [a], [r], [n], [f], {k: [v] for k,v in additional.items()}) else: - self.curr_rollout.states.append(s) + self.curr_rollout.obs.append(s) self.curr_rollout.actions.append(a) self.curr_rollout.rewards.append(r) - self.curr_rollout.is_finished.append(f) + self.curr_rollout.is_finished.append(n) + self.curr_rollout.is_first.append(f) + for k,v in additional.items(): + self.curr_rollout.additional_data[k].append(v) if f: self.add_rollout( - Rollout(np.array(self.curr_rollout.states), - np.array(self.curr_rollout.actions).reshape(len(self.curr_rollout.actions), -1), - np.array(self.curr_rollout.rewards, dtype=np.float32), - np.array([n]), np.array(self.curr_rollout.is_finished))) + Rollout( + torch.stack(self.curr_rollout.obs), + torch.stack(self.curr_rollout.actions).reshape(-1, 1), + torch.Tensor(self.curr_rollout.rewards), + torch.Tensor(self.curr_rollout.is_finished), + torch.Tensor(self.curr_rollout.is_first), + {k: torch.stack(v) for k,v in self.curr_rollout.additional_data.items()}) + ) self.curr_rollout = None def can_sample(self, num: int): @@ -90,21 +122,13 @@ def sample( self, batch_size: int, cluster_size: int = 1 - ) -> tuple[States, Actions, Rewards, States, TerminationFlags, IsFirstFlags]: + ) -> RolloutChunks: # NOTE: constant creation of numpy arrays from self.rollout_len seems terrible for me - s, a, r, n, t, is_first = [], [], [], [], [], [] - do_add_curr = self.curr_rollout is not None and len(self.curr_rollout.states) > (cluster_size * (self.prioritize_ends + 1)) - tot = self.total_num + (len(self.curr_rollout.states) if do_add_curr else 0) - r_indeces = np.random.choice(len(self.rollouts) + int(do_add_curr), - batch_size, - p=np.array(self.rollouts_len + deque([len(self.curr_rollout.states)] if do_add_curr else [])) / tot) + s, a, r, t, is_first, additional = [], [], [], [], [], {} + r_indeces = np.random.choice(len(self.rollouts), batch_size, p=np.array(self.rollouts_len) / self.total_num) s_indeces = [] for r_idx in r_indeces: - if r_idx != len(self.rollouts): - rollout, r_len = self.rollouts[r_idx], self.rollouts_len[r_idx] - else: - # -1 because we don't have next_state on terminal - rollout, r_len = self.curr_rollout, len(self.curr_rollout.states) - 1 + rollout, r_len = self.rollouts[r_idx], self.rollouts_len[r_idx] assert r_len > cluster_size - 1, "Rollout it too small" max_idx = r_len - cluster_size + 1 @@ -114,25 +138,31 @@ def sample( s_idx = np.random.choice(max_idx, 
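The rewritten sample() first draws an episode with probability proportional to its length, then cuts a contiguous cluster_size window out of it, biased towards the episode end when prioritize_ends is set. A stripped-down sketch of just the index selection (it assumes episodes already satisfy the min_ep_len guard, so every choice below is valid):

import numpy as np

def sample_window(lengths: list[int], cluster_size: int,
                  prioritize_ends: bool = False) -> tuple[int, int]:
    p = np.asarray(lengths) / np.sum(lengths)
    ep = np.random.choice(len(lengths), p=p)       # length-weighted episode choice
    max_start = lengths[ep] - cluster_size + 1
    if prioritize_ends:
        start = np.random.choice(max_start - cluster_size + 1) + cluster_size - 1
    else:
        start = np.random.choice(max_start)
    return ep, start                               # window covers [start, start + cluster_size)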
1).item() s_indeces.append(s_idx) - if r_idx == len(self.rollouts): - r_len += 1 - # FIXME: hot-fix for 1d action space, better to find smarter solution - actions = np.array(rollout.actions[s_idx:s_idx + cluster_size]).reshape(cluster_size, -1) - else: - actions = rollout.actions[s_idx:s_idx + cluster_size] - - is_first.append(np.zeros(cluster_size)) + is_first.append(torch.zeros(cluster_size)) if s_idx == 0: is_first[-1][0] = 1 - s.append(rollout.states[s_idx:s_idx + cluster_size]) - a.append(actions) + + s.append(rollout.obs[s_idx:s_idx + cluster_size]) + a.append(rollout.actions[s_idx:s_idx + cluster_size]) r.append(rollout.rewards[s_idx:s_idx + cluster_size]) t.append(rollout.is_finished[s_idx:s_idx + cluster_size]) - if s_idx != r_len - cluster_size: - n.append(rollout.states[s_idx+1:s_idx+1 + cluster_size]) - else: - if cluster_size != 1: - n.append(rollout.states[s_idx+1:s_idx+1 + cluster_size - 1]) - n.append(rollout.next_states) - return (np.concatenate(s), np.concatenate(a), np.concatenate(r, dtype=np.float32), - np.concatenate(n), np.concatenate(t), np.concatenate(is_first)) + for k,v in rollout.additional_data.items(): + if k not in additional: + additional[k] = [] + additional[k].append(v[s_idx:s_idx + cluster_size]) + + return RolloutChunks( + obs=torch.cat(s), + actions=torch.cat(a), + rewards=torch.cat(r).float(), + is_finished=torch.cat(t), + is_first=torch.cat(is_first), + additional_data={k: torch.cat(v) for k,v in additional.items()} + ).to(self.device, non_blocking=True) + + +# TODO: +# [X] Rewrite to use only torch containers +# [X] Add preprocessing step on adding to replay buffer +# [X] Add possibility to store additional auxilary data (dino encodings) +# [ ] (Optional) Utilize torch's dataloader for async sampling diff --git a/rl_sandbox/utils/replay_buffer_old.py b/rl_sandbox/utils/replay_buffer_old.py new file mode 100644 index 0000000..4572376 --- /dev/null +++ b/rl_sandbox/utils/replay_buffer_old.py @@ -0,0 +1,138 @@ +import typing as t +from collections import deque +from dataclasses import dataclass + +import numpy as np +from nptyping import Bool, Float, Int, NDArray, Shape + +Observation = NDArray[Shape["*,*,3"], Int] +State = NDArray[Shape["*"], Float] | Observation +Action = NDArray[Shape["*"], Int] + +Observations = NDArray[Shape["*,*,*,3"], Int] +States = NDArray[Shape["*,*"], Float] | Observations +Actions = NDArray[Shape["*,*"], Int] +Rewards = NDArray[Shape["*"], Float] +TerminationFlags = NDArray[Shape["*"], Bool] +IsFirstFlags = TerminationFlags + + +@dataclass +class Rollout: + states: States + actions: Actions + rewards: Rewards + next_states: States + is_finished: TerminationFlags + observations: t.Optional[Observations] = None + + def __len__(self): + return len(self.states) + +# TODO: make buffer concurrent-friendly +class ReplayBuffer: + + def __init__(self, max_len=2e6, + prioritize_ends: bool = False, + min_ep_len: int = 1, + device: str = 'cpu'): + self.rollouts: deque[Rollout] = deque() + self.rollouts_len: deque[int] = deque() + self.curr_rollout = None + self.min_ep_len = min_ep_len + self.prioritize_ends = prioritize_ends + self.max_len = max_len + self.total_num = 0 + self.device = device + + def __len__(self): + return self.total_num + + def add_rollout(self, rollout: Rollout): + if len(rollout.next_states) <= self.min_ep_len: + return + # NOTE: only last next state is stored, all others are induced + # from state on next step + rollout.next_states = np.expand_dims(rollout.next_states[-1], 0) + self.rollouts.append(rollout) + 
self.total_num += len(self.rollouts[-1].rewards) + self.rollouts_len.append(len(self.rollouts[-1].rewards)) + + while self.total_num >= self.max_len: + self.total_num -= self.rollouts_len[0] + self.rollouts_len.popleft() + self.rollouts.popleft() + + # Add sample expects that each subsequent sample + # will be continuation of last rollout util termination flag true + # is encountered + def add_sample(self, s: State, a: Action, r: float, n: State, f: bool): + if self.curr_rollout is None: + self.curr_rollout = Rollout([s], [a], [r], None, [f]) + else: + self.curr_rollout.states.append(s) + self.curr_rollout.actions.append(a) + self.curr_rollout.rewards.append(r) + self.curr_rollout.is_finished.append(f) + + if f: + self.add_rollout( + Rollout(np.array(self.curr_rollout.states), + np.array(self.curr_rollout.actions).reshape(len(self.curr_rollout.actions), -1), + np.array(self.curr_rollout.rewards, dtype=np.float32), + np.array([n]), np.array(self.curr_rollout.is_finished))) + self.curr_rollout = None + + def can_sample(self, num: int): + return self.total_num >= num + + def sample( + self, + batch_size: int, + cluster_size: int = 1 + ) -> tuple[States, Actions, Rewards, States, TerminationFlags, IsFirstFlags]: + # NOTE: constant creation of numpy arrays from self.rollout_len seems terrible for me + s, a, r, n, t, is_first = [], [], [], [], [], [] + do_add_curr = self.curr_rollout is not None and len(self.curr_rollout.states) > (cluster_size * (self.prioritize_ends + 1)) + tot = self.total_num + (len(self.curr_rollout.states) if do_add_curr else 0) + r_indeces = np.random.choice(len(self.rollouts) + int(do_add_curr), + batch_size, + p=np.array(self.rollouts_len + deque([len(self.curr_rollout.states)] if do_add_curr else [])) / tot) + s_indeces = [] + for r_idx in r_indeces: + if r_idx != len(self.rollouts): + rollout, r_len = self.rollouts[r_idx], self.rollouts_len[r_idx] + else: + # -1 because we don't have next_state on terminal + rollout, r_len = self.curr_rollout, len(self.curr_rollout.states) - 1 + + assert r_len > cluster_size - 1, "Rollout it too small" + max_idx = r_len - cluster_size + 1 + if self.prioritize_ends: + s_idx = np.random.choice(max_idx - cluster_size + 1, 1).item() + cluster_size - 1 + else: + s_idx = np.random.choice(max_idx, 1).item() + s_indeces.append(s_idx) + + if r_idx == len(self.rollouts): + r_len += 1 + # FIXME: hot-fix for 1d action space, better to find smarter solution + actions = np.array(rollout.actions[s_idx:s_idx + cluster_size]).reshape(cluster_size, -1) + else: + actions = rollout.actions[s_idx:s_idx + cluster_size] + + is_first.append(np.zeros(cluster_size)) + if s_idx == 0: + is_first[-1][0] = 1 + s.append(rollout.states[s_idx:s_idx + cluster_size]) + a.append(actions) + r.append(rollout.rewards[s_idx:s_idx + cluster_size]) + t.append(rollout.is_finished[s_idx:s_idx + cluster_size]) + if s_idx != r_len - cluster_size: + n.append(rollout.states[s_idx+1:s_idx+1 + cluster_size]) + else: + if cluster_size != 1: + n.append(rollout.states[s_idx+1:s_idx+1 + cluster_size - 1]) + n.append(rollout.next_states) + return (np.concatenate(s), np.concatenate(a), np.concatenate(r, dtype=np.float32), + np.concatenate(n), np.concatenate(t), np.concatenate(is_first)) diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index 3642c44..0d2b08d 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -1,110 +1,106 @@ import typing as t +from collections import defaultdict from 
multiprocessing.synchronize import Lock -from IPython.core.inputtransformer2 import warnings import numpy as np +import torch import torch.multiprocessing as mp +from IPython.core.inputtransformer2 import warnings from unpackable import unpack from rl_sandbox.agents.random_agent import RandomAgent from rl_sandbox.agents.rl_agent import RlAgent from rl_sandbox.utils.env import Env -from rl_sandbox.utils.replay_buffer import (Action, Observation, ReplayBuffer, - Rollout, State) - - -def _async_env_worker(env: Env, obs_queue: mp.Queue, act_queue: mp.Queue): - state, _, terminated = unpack(env.reset()) - obs_queue.put((state, 0, terminated), block=False) - - while not terminated: - action = act_queue.get(block=True) - - new_state, reward, terminated = unpack(env.step(action)) - del action - obs_queue.put((state, reward, terminated), block=False) - - state = new_state - - -def iter_rollout_async( - env: Env, - agent: RlAgent -) -> t.Generator[tuple[State, Action, float, State, bool, t.Optional[Observation]], None, - None]: - # NOTE: maybe use SharedMemory instead - obs_queue = mp.Queue(1) - a_queue = mp.Queue(1) - p = mp.Process(target=_async_env_worker, args=(env, obs_queue, a_queue)) - p.start() - terminated = False - - while not terminated: - state, reward, terminated = obs_queue.get(block=True) - action = agent.get_action(state) - a_queue.put(action) - yield state, action, reward, None, terminated, state - - -def iter_rollout( - env: Env, - agent: RlAgent, - collect_obs: bool = False -) -> t.Generator[tuple[State, Action, float, State, bool, t.Optional[Observation]], None, - None]: +from rl_sandbox.utils.replay_buffer import EnvStep, ReplayBuffer, Rollout + +# (Action, Observation, ReplayBuffer, Rollout, State) + +# FIXME: obsolete, need to be updated for new replay buffer +# def _async_env_worker(env: Env, obs_queue: mp.Queue, act_queue: mp.Queue): +# state, _, terminated = unpack(env.reset()) +# obs_queue.put((state, 0, terminated), block=False) + +# while not terminated: +# action = act_queue.get(block=True) + +# new_state, reward, terminated = unpack(env.step(action)) +# del action +# obs_queue.put((state, reward, terminated), block=False) + +# state = new_state + +# def iter_rollout_async( +# env: Env, +# agent: RlAgent +# ) -> t.Generator[tuple[State, Action, float, State, bool, t.Optional[Observation]], None, +# None]: +# # NOTE: maybe use SharedMemory instead +# obs_queue = mp.Queue(1) +# a_queue = mp.Queue(1) +# p = mp.Process(target=_async_env_worker, args=(env, obs_queue, a_queue)) +# p.start() +# terminated = False + +# while not terminated: +# state, reward, terminated = obs_queue.get(block=True) +# action = agent.get_action(state) +# a_queue.put(action) +# yield state, action, reward, None, terminated, state + + +def iter_rollout(env: Env, + agent: RlAgent, + collect_obs: bool = False) -> t.Generator[EnvStep, None, None]: state, _, terminated = unpack(env.reset()) agent.reset() - prev_action = np.zeros_like(agent.get_action(state)) - prev_reward = 0 - prev_terminated = False - while not terminated: - action = agent.get_action(state) - - new_state, reward, terminated = unpack(env.step(action)) + reward = 0.0 + is_first = True + action = torch.zeros_like(agent.get_action(state)) + while not terminated: try: obs = env.render() if collect_obs else None except RuntimeError: # FIXME: hot-fix for Crafter env to work warnings.warn("Cannot render environment, using state instead") obs = state - # if collect_obs and isinstance(env, dmEnv): - yield state, prev_action, prev_reward, 
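The generator being rewritten in this hunk yields each step before querying the agent, so the stored action and reward are the ones that led to the current observation, and is_first marks the reset transition. In schematic form (the (obs, reward, done) env API mirrors the unpack calls above; names are illustrative):

import torch

def iter_rollout_sketch(env, agent):
    obs, _, done = env.reset()
    agent.reset()
    reward, is_first = 0.0, True
    action = torch.zeros_like(agent.get_action(obs))
    while not done:
        yield obs, action, reward, done, is_first  # transition *into* obs
        is_first = False
        action = agent.get_action(obs)
        obs, reward, done = env.step(action)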
new_state, prev_terminated, obs - state = new_state - prev_action = action - prev_reward = reward - prev_terminated = terminated + + # FIXME: works only for crafter + yield EnvStep(obs=torch.from_numpy(state), + action=torch.Tensor(action).squeeze(), + reward=reward, + is_finished=terminated, + is_first=is_first) + is_first = False + + action = agent.get_action(state) + + state, reward, terminated = unpack(env.step(action)) def collect_rollout(env: Env, agent: t.Optional[RlAgent] = None, collect_obs: bool = False) -> Rollout: - s, a, r, n, f, o = [], [], [], [], [], [] + s, a, r, t, f, additional = [], [], [], [], [], defaultdict(list) if agent is None: agent = RandomAgent(env) - for state, action, reward, new_state, terminated, obs in iter_rollout( - env, agent, collect_obs): - s.append(state) + for step in iter_rollout(env, agent, collect_obs): + obs, action, reward, terminated, first, add = unpack(step) + s.append(obs) a.append(action) r.append(reward) - n.append(new_state) - f.append(terminated) - - # FIXME: will break for non-DM - if collect_obs: - o.append(obs) + t.append(terminated) + f.append(first) + for k, v in add.items(): + additional[k].append(v) - # match env: - # case gym.Env(): - # obs = np.stack(list(env.render())) if obs_res is not None else None - # case dmEnv(): - obs = np.array(o) if collect_obs is not None else None - return Rollout(np.array(s), - np.array(a).reshape(len(s), -1), np.array(r, dtype=np.float32), - np.array(n), np.array(f), obs) + return Rollout(torch.stack(s), torch.stack(a).reshape(-1, 1), + torch.Tensor(r).float(), torch.Tensor(t), torch.Tensor(f), + {k: torch.stack(v) + for k, v in additional.items()}) def collect_rollout_num(env: Env, @@ -118,7 +114,7 @@ def collect_rollout_num(env: Env, return rollouts -def fillup_replay_buffer(env: Env, rep_buffer: ReplayBuffer, num: int): +def fillup_replay_buffer(env: Env, rep_buffer: ReplayBuffer, num: int, agent: t.Optional[RlAgent] = None): # TODO: paralelyze while not rep_buffer.can_sample(num): - rep_buffer.add_rollout(collect_rollout(env, collect_obs=False)) + rep_buffer.add_rollout(collect_rollout(env, agent=agent, collect_obs=False)) From 26f83a7b2893368fccc4bf353d204eee8eaf9874 Mon Sep 17 00:00:00 2001 From: Midren Date: Wed, 21 Jun 2023 16:44:43 +0200 Subject: [PATCH 070/106] Added scheduling for attention --- .vimspector.json | 2 +- rl_sandbox/agents/dreamer/rssm.py | 3 + rl_sandbox/agents/dreamer/rssm_slots.py | 3 + .../agents/dreamer/rssm_slots_attention.py | 21 ++++- .../agents/dreamer/rssm_slots_combined.py | 3 + rl_sandbox/agents/dreamer/world_model.py | 1 + .../agents/dreamer/world_model_slots.py | 3 +- .../dreamer/world_model_slots_attention.py | 18 ++++- .../dreamer/world_model_slots_combined.py | 1 + .../config/agent/dreamer_v2_crafter.yaml | 8 +- .../agent/dreamer_v2_crafter_slotted.yaml | 18 ++--- .../config/agent/dreamer_v2_slotted.yaml | 77 ------------------- .../config/agent/dreamer_v2_slotted_dino.yaml | 73 ------------------ rl_sandbox/config/config.yaml | 8 +- rl_sandbox/config/config_slotted.yaml | 5 +- rl_sandbox/config/config_slotted_debug.yaml | 4 +- rl_sandbox/config/training/crafter.yaml | 2 +- rl_sandbox/metrics.py | 9 ++- rl_sandbox/train.py | 1 - rl_sandbox/utils/dists.py | 1 + 20 files changed, 78 insertions(+), 183 deletions(-) delete mode 100644 rl_sandbox/config/agent/dreamer_v2_slotted.yaml delete mode 100644 rl_sandbox/config/agent/dreamer_v2_slotted_dino.yaml diff --git a/.vimspector.json b/.vimspector.json index 44116be..5bedd5f 100644 --- a/.vimspector.json +++ 
b/.vimspector.json @@ -40,7 +40,7 @@ "extends": "python-base", "configuration": { "program": "rl_sandbox/train.py", - "args": [] + "args": ["logger.type='tensorboard'", "training.prefill=0", "training.batch_size=4"] } }, "Run dino": { diff --git a/rl_sandbox/agents/dreamer/rssm.py b/rl_sandbox/agents/dreamer/rssm.py index 0c65549..d99816b 100644 --- a/rl_sandbox/agents/dreamer/rssm.py +++ b/rl_sandbox/agents/dreamer/rssm.py @@ -166,6 +166,9 @@ def estimate_stochastic_latent(self, prev_determ: torch.Tensor): idx = torch.randint(0, self.ensemble_num, ()) return dists_per_model[0] + def on_train_step(self): + pass + def predict_next(self, prev_state: State, action) -> State: diff --git a/rl_sandbox/agents/dreamer/rssm_slots.py b/rl_sandbox/agents/dreamer/rssm_slots.py index ec52393..cbbe403 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots.py +++ b/rl_sandbox/agents/dreamer/rssm_slots.py @@ -115,6 +115,9 @@ def __init__(self, latent_dim * self.latent_classes), # Dreamer 'obs_dist' View((1, -1, latent_dim, self.latent_classes))) + def on_train_step(self): + pass + def estimate_stochastic_latent(self, prev_determ: torch.Tensor): dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] # NOTE: Maybe something smarter can be used instead of diff --git a/rl_sandbox/agents/dreamer/rssm_slots_attention.py b/rl_sandbox/agents/dreamer/rssm_slots_attention.py index 902c6c6..658cd0e 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_attention.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_attention.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from rl_sandbox.agents.dreamer import Dist, View, GRUCell +from rl_sandbox.utils.schedulers import LinearScheduler @dataclass @@ -74,6 +75,9 @@ def __init__(self, latent_classes, discrete_rssm, norm_layer: nn.LayerNorm | nn.Identity, + full_qk_from: int = 1, + symmetric_qk: bool = False, + attention_block_num: int = 3, embed_size=2 * 2 * 384): super().__init__() self.latent_dim = latent_dim @@ -82,6 +86,8 @@ def __init__(self, self.hidden_size = hidden_size self.discrete_rssm = discrete_rssm + self.symmetric_qk = symmetric_qk + # Calculate deterministic state from prev stochastic, prev action and prev deterministic self.pre_determ_recurrent = nn.Sequential( nn.Linear(latent_dim * latent_classes + actions_num, @@ -122,10 +128,14 @@ def __init__(self, self.fc = nn.Linear(hidden_size, hidden_size) self.fc_norm = nn.LayerNorm(hidden_size) - self.attention_block_num = 3 + self.attention_scheduler = LinearScheduler(0.0, 1.0, full_qk_from) + self.attention_block_num = attention_block_num self.att_scale = hidden_size**(-0.5) self.eps = 1e-8 + def on_train_step(self): + self.attention_scheduler.step() + def estimate_stochastic_latent(self, prev_determ: torch.Tensor): dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] # NOTE: Maybe something smarter can be used instead of @@ -156,15 +166,22 @@ def predict_next(self, prev_state: State, action) -> State: # Experiment, when only stochastic part is affected and deterministic is not touched # We keep flow of gradients through determ block, but updating it with stochastic part for _ in range(self.attention_block_num): - q, k, v = self.hidden_attention_proj(self.pre_norm(determ_post)).chunk(3, dim=-1) # 1xBxSlots_numxHidden_size + q, k, v = self.hidden_attention_proj(self.pre_norm(determ_post)).chunk(3, dim=-1) # + if self.symmetric_qk: + k = q qk = torch.einsum('lbih,lbjh->lbij', q, k) attn = torch.softmax(self.att_scale * qk + self.eps, dim=-1) attn = attn / 
attn.sum(dim=-1, keepdim=True) + coeff = self.attention_scheduler.val + attn = coeff * attn + (1 - coeff) * torch.eye(q.shape[-2],device=q.device) + updates = torch.einsum('lbij,lbjh->lbih', qk, v) determ_post = determ_post + self.fc(self.fc_norm(updates)) + self.last_attention = attn.mean(dim=1).squeeze() + # used for KL divergence predicted_stoch_logits = self.estimate_stochastic_latent(determ_post.reshape(determ_prior.shape)).reshape(prev_state.stoch_logits.shape) # Size is 1 x B x slots_num x ... diff --git a/rl_sandbox/agents/dreamer/rssm_slots_combined.py b/rl_sandbox/agents/dreamer/rssm_slots_combined.py index 3628915..c471c44 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_combined.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_combined.py @@ -150,6 +150,9 @@ def __init__(self, latent_dim * self.latent_classes), # Dreamer 'obs_dist' View((1, -1, latent_dim, self.latent_classes))) + def on_train_step(self): + pass + def estimate_stochastic_latent(self, prev_determ: torch.Tensor): dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] # NOTE: Maybe something smarter can be used instead of diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index 41818b9..9bcbd93 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -142,6 +142,7 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State]) -> Sta def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): + self.recurrent_model.on_train_step() b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index 24c5999..8a57f13 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -136,7 +136,7 @@ def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) obs = ToTensor(obs + 0.5) - d_features = self.dino_vit(obs.unsqueeze(0)).squeeze() + d_features = self.dino_vit(obs) return {'d_features': d_features} def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): @@ -192,6 +192,7 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): + self.recurrent_model.on_train_step() b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index 10f7101..5ee43f3 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -20,7 +20,8 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, - decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True): + decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, + full_qk_from: int = 1, symmetric_qk: bool = False, attention_block_num: 
int = 3): super().__init__() self.use_prev_slots = use_prev_slots self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) @@ -50,7 +51,10 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, latent_classes, discrete_rssm, norm_layer=nn.Identity if layer_norm else nn.LayerNorm, - embed_size=self.n_dim) + embed_size=self.n_dim, + full_qk_from=full_qk_from, + symmetric_qk=symmetric_qk, + attention_block_num=attention_block_num) if encode_vit or decode_vit: # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) @@ -136,7 +140,7 @@ def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) obs = ToTensor(obs + 0.5) - d_features = self.dino_vit(obs.unsqueeze(0)).squeeze() + d_features = self.dino_vit(obs) return {'d_features': d_features} def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): @@ -192,6 +196,7 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): + self.recurrent_model.on_train_step() b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) @@ -229,6 +234,9 @@ def KL(dist1, dist2): d_features = additional['d_features'] prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) + + self.last_attn = torch.zeros((self.slots_num, self.slots_num), device=a_c.device) + for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 pre_slot_feature_t, a_t, first_t = pre_slot_features_c[:, @@ -249,12 +257,15 @@ def KL(dist1, dist2): prior, posterior, diff = self.recurrent_model.forward( prev_state, slots_t.unsqueeze(0), a_t) prev_state = posterior + self.last_attn += self.recurrent_model.last_attention priors.append(prior) posteriors.append(posterior) # losses['loss_determ_recons'] += diff + self.last_attn /= self.cluster_size + posterior = State.stack(posteriors) prior = State.stack(priors) @@ -298,6 +309,7 @@ def KL(dist1, dist2): losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) + metrics['attention_coeff'] = torch.tensor(self.recurrent_model.attention_scheduler.val) metrics['reward_mean'] = r.mean() metrics['reward_std'] = r.std() metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index d788975..dd45b73 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -193,6 +193,7 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): + self.recurrent_model.on_train_step() b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index ef299bb..12e7f78 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ 
b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -1,7 +1,7 @@ _target_: rl_sandbox.agents.DreamerV2 imagination_horizon: 15 -batch_cluster_size: 50 +batch_cluster_size: 20 layer_norm: true world_model: @@ -45,7 +45,7 @@ wm_optim: _target_: rl_sandbox.utils.optimizer.Optimizer _partial_: true lr_scheduler: null - lr: 2e-4 + lr: 1e-4 eps: 1e-5 weight_decay: 1e-6 clip: 100 @@ -53,7 +53,7 @@ wm_optim: actor_optim: _target_: rl_sandbox.utils.optimizer.Optimizer _partial_: true - lr: 2e-4 + lr: 1e-4 eps: 1e-5 weight_decay: 1e-6 clip: 100 @@ -61,7 +61,7 @@ actor_optim: critic_optim: _target_: rl_sandbox.utils.optimizer.Optimizer _partial_: true - lr: 2e-4 + lr: 1e-4 eps: 1e-5 weight_decay: 1e-6 clip: 100 diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml index d2fd6ca..e86ead4 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml @@ -1,25 +1,25 @@ _target_: rl_sandbox.agents.DreamerV2 imagination_horizon: 15 -batch_cluster_size: 50 +batch_cluster_size: 20 layer_norm: true world_model: _target_: rl_sandbox.agents.dreamer.world_model_slots.WorldModel _partial_: true batch_cluster_size: ${..batch_cluster_size} - latent_dim: 22 + latent_dim: 32 latent_classes: ${.latent_dim} - rssm_dim: 256 + rssm_dim: 512 slots_num: 6 slots_iter_num: 2 - kl_loss_scale: 1e1 + kl_loss_scale: 1e2 kl_loss_balancing: 0.8 kl_free_nats: 0.00 discrete_rssm: false - decode_vit: false + decode_vit: true use_prev_slots: false - vit_l2_ratio: 1.0 + vit_l2_ratio: 0.1 encode_vit: false predict_discount: true layer_norm: ${..layer_norm} @@ -51,7 +51,7 @@ wm_optim: - _target_: rl_sandbox.utils.optimizer.WarmupScheduler _partial_: true warmup_steps: 1e3 - lr: 2e-4 + lr: 1e-4 eps: 1e-5 weight_decay: 1e-6 clip: 100 @@ -59,7 +59,7 @@ wm_optim: actor_optim: _target_: rl_sandbox.utils.optimizer.Optimizer _partial_: true - lr: 2e-4 + lr: 1e-4 eps: 1e-5 weight_decay: 1e-6 clip: 100 @@ -67,7 +67,7 @@ actor_optim: critic_optim: _target_: rl_sandbox.utils.optimizer.Optimizer _partial_: true - lr: 2e-4 + lr: 1e-4 eps: 1e-5 weight_decay: 1e-6 clip: 100 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted.yaml deleted file mode 100644 index 33241af..0000000 --- a/rl_sandbox/config/agent/dreamer_v2_slotted.yaml +++ /dev/null @@ -1,77 +0,0 @@ -_target_: rl_sandbox.agents.DreamerV2 - -imagination_horizon: 15 -batch_cluster_size: 50 -layer_norm: true - -world_model: - _target_: rl_sandbox.agents.dreamer.world_model_slots.WorldModel - _partial_: true - batch_cluster_size: ${..batch_cluster_size} - latent_dim: 22 - latent_classes: 22 - rssm_dim: 80 - slots_num: 4 - slots_iter_num: 5 - kl_loss_scale: 2.0 - kl_loss_balancing: 0.8 - kl_free_nats: 0.05 - discrete_rssm: false - decode_vit: false - vit_l2_ratio: 1.0 - use_prev_slots: true - encode_vit: false - predict_discount: false - layer_norm: ${..layer_norm} - -actor: - _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor - _partial_: true - # mixing of reinforce and maximizing value func - # for dm_control it is zero in Dreamer (Atari 1) - reinforce_fraction: null - entropy_scale: 1e-4 - layer_norm: ${..layer_norm} - -critic: - _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic - _partial_: true - discount_factor: 0.999 - update_interval: 100 - # [0-1], 1 means hard update - soft_update_fraction: 1 - # Lambda parameter for trainin deeper multi-step prediction - value_target_lambda: 0.95 
- layer_norm: ${..layer_norm} - -wm_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr_scheduler: - - _target_: rl_sandbox.utils.optimizer.WarmupScheduler - _partial_: true - warmup_steps: 1e3 - #- _target_: rl_sandbox.utils.optimizer.DecayScheduler - # _partial_: true - # decay_rate: 0.5 - # decay_steps: 5e5 - lr: 3e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 - -actor_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr: 8e-5 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 - -critic_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr: 8e-5 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_dino.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_dino.yaml deleted file mode 100644 index 1d53297..0000000 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_dino.yaml +++ /dev/null @@ -1,73 +0,0 @@ -_target_: rl_sandbox.agents.DreamerV2 - -imagination_horizon: 15 -batch_cluster_size: 50 -layer_norm: true - -world_model: - _target_: rl_sandbox.agents.dreamer.world_model_slots_dino.WorldModel - _partial_: true - batch_cluster_size: ${..batch_cluster_size} - latent_dim: 32 - latent_classes: ${.latent_dim} - rssm_dim: 256 - slots_num: 6 - slots_iter_num: 2 - kl_loss_scale: 1e2 - kl_loss_balancing: 0.8 - kl_free_nats: 0.00 - discrete_rssm: false - decode_vit: true - use_prev_slots: false - vit_l2_ratio: 0.8 - encode_vit: false - predict_discount: true - layer_norm: ${..layer_norm} - -actor: - _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor - _partial_: true - # mixing of reinforce and maximizing value func - # for dm_control it is zero in Dreamer (Atari 1) - reinforce_fraction: null - entropy_scale: 3e-3 - layer_norm: ${..layer_norm} - -critic: - _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic - _partial_: true - discount_factor: 0.999 - update_interval: 100 - # [0-1], 1 means hard update - soft_update_fraction: 1 - # Lambda parameter for trainin deeper multi-step prediction - value_target_lambda: 0.95 - layer_norm: ${..layer_norm} - -wm_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr_scheduler: - - _target_: rl_sandbox.utils.optimizer.WarmupScheduler - _partial_: true - warmup_steps: 1e3 - lr: 2e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 - -actor_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr: 2e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 - -critic_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr: 2e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index fd15765..859e9ee 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,5 +1,5 @@ defaults: - - agent: dreamer_v2_slotted_debug + - agent: dreamer_v2_crafter - env: crafter - training: crafter - _self_ @@ -10,13 +10,13 @@ device_type: cuda logger: type: tensorboard - message: Crafter 6 DINO slots, 32 latents, 256 rssm + message: Crafter default log_grads: false training: checkpoint_path: null steps: 1e6 - val_logs_every: 2e4 + val_logs_every: 5e2 validation: rollout_num: 5 @@ -25,7 +25,7 @@ validation: - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator log_video: True _partial_: true - - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator _partial_: true debug: diff --git a/rl_sandbox/config/config_slotted.yaml 
b/rl_sandbox/config/config_slotted.yaml index f28cac8..0ec9369 100644 --- a/rl_sandbox/config/config_slotted.yaml +++ b/rl_sandbox/config/config_slotted.yaml @@ -10,14 +10,13 @@ device_type: cuda logger: type: tensorboard - message: Cartpole 4 slots, 384 n_dim, 80 rssm dims, 22x22 stoch + message: Cartpole with slot attention, 1e3 kl, 2 iter num, free nats log_grads: false training: checkpoint_path: null steps: 1e6 - val_logs_every: 5e4 - + val_logs_every: 1e4 validation: rollout_num: 5 diff --git a/rl_sandbox/config/config_slotted_debug.yaml b/rl_sandbox/config/config_slotted_debug.yaml index 4eaef5d..96f3d87 100644 --- a/rl_sandbox/config/config_slotted_debug.yaml +++ b/rl_sandbox/config/config_slotted_debug.yaml @@ -1,5 +1,5 @@ defaults: - - agent: dreamer_v2_slotted_debug + - agent: dreamer_v2_crafter_slotted - env: crafter - training: crafter - _self_ @@ -10,7 +10,7 @@ device_type: cuda logger: type: tensorboard - message: Crafter 5 slots, 1e2 kl loss, 0.999 vit + message: Crafter 6 DINO slots log_grads: false training: diff --git a/rl_sandbox/config/training/crafter.yaml b/rl_sandbox/config/training/crafter.yaml index 85b4242..2d23c4b 100644 --- a/rl_sandbox/config/training/crafter.yaml +++ b/rl_sandbox/config/training/crafter.yaml @@ -4,5 +4,5 @@ batch_size: 16 pretrain: 1 prioritize_ends: true train_every: 5 -save_checkpoint_every: 2e5 +save_checkpoint_every: 5e5 val_logs_every: 2e4 diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 0d2acb9..a1447e2 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -161,11 +161,16 @@ def on_step(self, logger): self._latent_probs += self.agent._state[0].stoch_dist.probs.squeeze().mean(dim=0) def on_episode(self, logger): - mu = self.agent.world_model.slot_attention.slots_mu - sigma = self.agent.world_model.slot_attention.slots_logsigma.exp() + wm = self.agent.world_model + + mu = wm.slot_attention.slots_mu + sigma = wm.slot_attention.slots_logsigma.exp() mu_hist = torch.mean((mu - mu.squeeze(0).unsqueeze(1)) ** 2, dim=-1) sigma_hist = torch.mean((sigma - sigma.squeeze(0).unsqueeze(1)) ** 2, dim=-1) + if wm.recurrent_model.last_attention is not None: + logger.add_image('val/mixer_attention', wm.recurrent_model.last_attention, self.episode, dataformats='HW') + logger.add_image('val/slot_attention_mu', mu_hist/mu_hist.max(), self.episode, dataformats='HW') logger.add_image('val/slot_attention_sigma', sigma_hist/sigma_hist.max(), self.episode, dataformats='HW') diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 3c45597..f7f6e61 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -115,7 +115,6 @@ def main(cfg: DictConfig): ### Training and exploration for env_step in iter_rollout(env, agent): - # env_step = agent.preprocess(env_step) buff.add_sample(env_step) if global_step % cfg.training.train_every == 0: diff --git a/rl_sandbox/utils/dists.py b/rl_sandbox/utils/dists.py index 8f7032d..532bbf1 100644 --- a/rl_sandbox/utils/dists.py +++ b/rl_sandbox/utils/dists.py @@ -4,6 +4,7 @@ from numbers import Number import typing as t +import numpy as np import torch import torch.distributions as td from torch import nn From f8846b6f74e1c956cd57baa710cbde31bdccc4c9 Mon Sep 17 00:00:00 2001 From: Midren Date: Wed, 21 Jun 2023 20:15:03 +0200 Subject: [PATCH 071/106] Added hard/soft mixing and per slot mse mean loss --- .../agents/dreamer/world_model_slots.py | 24 ++++- .../dreamer/world_model_slots_attention.py | 93 +++++++++++++++---- .../dreamer/world_model_slots_combined.py | 91 ++++++++++++++---- 
rl_sandbox/agents/dreamer_v2.py | 3 +- rl_sandbox/metrics.py | 8 +- 5 files changed, 171 insertions(+), 48 deletions(-) diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index 8a57f13..c6ebe0b 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -20,7 +20,8 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, - decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True): + decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, + mask_combination: str = 'soft'): super().__init__() self.use_prev_slots = use_prev_slots self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) @@ -30,6 +31,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.latent_dim = latent_dim self.latent_classes = latent_classes self.slots_num = slots_num + self.mask_combination = mask_combination self.state_size = slots_num * (rssm_dim + latent_dim * latent_classes) self.cluster_size = batch_cluster_size @@ -126,6 +128,18 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: + match self.mask_combination: + case 'soft': + img_mask = F.softmax(masks, dim=1) + case 'hard': + img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + case 'qmix': + raise NotImplementedError + case _: + raise NotImplementedError + return img_mask + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: if not self.decode_vit: return {} @@ -266,7 +280,7 @@ def KL(dist1, dist2): if not self.decode_vit: decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) + img_mask = self.slot_mask(masks) decoded_imgs = decoded_imgs * img_mask x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) @@ -274,20 +288,20 @@ def KL(dist1, dist2): else: if self.vit_l2_ratio != 1.0: decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) + img_mask = self.slot_mask(masks) decoded_imgs = decoded_imgs * img_mask x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean() else: img_rec = 0 decoded_imgs_detached, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) + img_mask = self.slot_mask(masks) decoded_imgs_detached = decoded_imgs_detached * img_mask x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, 8, 8).split([self.vit_feat_dim, 1], dim=2) - feat_mask = F.softmax(masks, dim=1) + feat_mask = 
self.slot_mask(masks) decoded_feats = decoded_feats * feat_mask d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 8, 8)).float().mean() + diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index 5ee43f3..b8178dc 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -21,7 +21,11 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, - full_qk_from: int = 1, symmetric_qk: bool = False, attention_block_num: int = 3): + full_qk_from: int = 1, + symmetric_qk: bool = False, + attention_block_num: int = 3, + mask_combination: str = 'soft', + per_slot_rec_loss: bool = False): super().__init__() self.use_prev_slots = use_prev_slots self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) @@ -31,6 +35,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.latent_dim = latent_dim self.latent_classes = latent_classes self.slots_num = slots_num + self.mask_combination = mask_combination self.state_size = slots_num * (rssm_dim + latent_dim * latent_classes) self.cluster_size = batch_cluster_size @@ -41,6 +46,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.encode_vit = encode_vit self.decode_vit = decode_vit self.vit_l2_ratio = vit_l2_ratio + self.per_slot_rec_loss = per_slot_rec_loss self.n_dim = 384 @@ -130,6 +136,18 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: + match self.mask_combination: + case 'soft': + img_mask = F.softmax(masks, dim=1) + case 'hard': + img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + case 'qmix': + raise NotImplementedError + case _: + raise NotImplementedError + return img_mask + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: if not self.decode_vit: return {} @@ -272,36 +290,73 @@ def KL(dist1, dist2): r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) - losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) + losses['loss_reconstruction_img'] = torch.tensor(0, device=obs.device) if not self.decode_vit: decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) - decoded_imgs = decoded_imgs * img_mask - x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + img_mask = self.slot_mask(masks) - losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() - else: - if self.vit_l2_ratio != 1.0: - decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) + if self.per_slot_rec_loss: + l2_loss = (img_mask * ((decoded_imgs - 
obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) + normalizing_factor = (torch.prod(torch.tensor(obs.shape[1:]))) / img_mask.sum(dim=[2, 3, 4]) + # magic constant that describes the difference between log_prob and mse losses + img_rec = (l2_loss * normalizing_factor).sum(dim=1).mean() * self.slots_num * 8 + decoded_imgs = decoded_imgs * img_mask + else: decoded_imgs = decoded_imgs * img_mask x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean() + + losses['loss_reconstruction'] = img_rec + else: + if self.vit_l2_ratio != 1.0: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = self.slot_mask(masks) + + if self.per_slot_rec_loss: + l2_loss = (img_mask*((decoded_imgs - obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) + normalizing_factor = (torch.prod(torch.tensor(obs.shape[1:])))/img_mask.sum(dim=[2, 3, 4]) + # magic constant that describes the difference between log_prob and mse losses + img_rec = (l2_loss * normalizing_factor).sum(dim=1).mean() * self.slots_num * 8 + decoded_imgs = decoded_imgs * img_mask + else: + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean() else: img_rec = 0 decoded_imgs_detached, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) - decoded_imgs_detached = decoded_imgs_detached * img_mask - x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) - losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() + img_mask = self.slot_mask(masks) + + if self.per_slot_rec_loss: + l2_loss = (img_mask*((decoded_imgs_detached - obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) + normalizing_factor = (torch.prod(torch.tensor(obs.shape[1:])))/img_mask.sum(dim=[2, 3, 4]) + # magic constant that describes the difference between log_prob and mse losses + img_rec_detached = (l2_loss * normalizing_factor).sum(dim=1).mean() * self.slots_num * 8 + else: + decoded_imgs_detached = decoded_imgs_detached * img_mask + x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) + img_rec_detached = -x_r_detached.log_prob(obs).float().mean() + + losses['loss_reconstruction_img'] = img_rec_detached decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, 8, 8).split([self.vit_feat_dim, 1], dim=2) - feat_mask = F.softmax(masks, dim=1) - decoded_feats = decoded_feats * feat_mask - d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) - losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 8, 8)).float().mean() + - (1-self.vit_l2_ratio) * img_rec) + feat_mask = self.slot_mask(masks) + + d_obs = d_features.reshape(b, self.vit_feat_dim, 8, 8) + + if self.per_slot_rec_loss: + l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) + normalizing_factor = (torch.prod(torch.tensor(d_obs.shape[1:])))/feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # # magic constant that describes the difference between log_prob and mse losses + d_rec = (l2_loss * normalizing_factor).sum(dim=1).mean()*self.slots_num * 4 + decoded_feats = decoded_feats * feat_mask + else: + 
decoded_feats = decoded_feats * feat_mask + d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) + d_rec = -d_pred.log_prob(d_obs).float().mean() + + losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) prior_logits = prior.stoch_logits posterior_logits = posterior.stoch_logits diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index dd45b73..875d45b 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -20,7 +20,9 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, - decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True): + decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, + mask_combination: str = 'soft', + per_slot_rec_loss: bool = False): super().__init__() self.use_prev_slots = use_prev_slots self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) @@ -30,6 +32,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.latent_dim = latent_dim self.latent_classes = latent_classes self.slots_num = slots_num + self.mask_combination = mask_combination self.state_size = slots_num * (rssm_dim + latent_dim * latent_classes) self.cluster_size = batch_cluster_size @@ -40,6 +43,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.encode_vit = encode_vit self.decode_vit = decode_vit self.vit_l2_ratio = vit_l2_ratio + self.per_slot_rec_loss = per_slot_rec_loss self.n_dim = 384 @@ -127,6 +131,18 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: + match self.mask_combination: + case 'soft': + img_mask = F.softmax(masks, dim=1) + case 'hard': + img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + case 'qmix': + raise NotImplementedError + case _: + raise NotImplementedError + return img_mask + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: if not self.decode_vit: return {} @@ -263,36 +279,73 @@ def KL(dist1, dist2): r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) - losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) + losses['loss_reconstruction_img'] = torch.tensor(0, device=obs.device) if not self.decode_vit: decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) - decoded_imgs = decoded_imgs * img_mask - x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + img_mask = self.slot_mask(masks) - losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() - else: - if self.vit_l2_ratio != 1.0: - decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) + if self.per_slot_rec_loss: + l2_loss = 
(img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) + normalizing_factor = (torch.prod(torch.tensor(obs.shape[1:]))) / img_mask.sum(dim=[2, 3, 4]) + # magic constant that describes the difference between log_prob and mse losses + img_rec = (l2_loss * normalizing_factor).sum(dim=1).mean() * self.slots_num * 8 + decoded_imgs = decoded_imgs * img_mask + else: decoded_imgs = decoded_imgs * img_mask x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean() + + losses['loss_reconstruction'] = img_rec + else: + if self.vit_l2_ratio != 1.0: + decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) + img_mask = self.slot_mask(masks) + + if self.per_slot_rec_loss: + l2_loss = (img_mask*((decoded_imgs - obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) + normalizing_factor = (torch.prod(torch.tensor(obs.shape[1:])))/img_mask.sum(dim=[2, 3, 4]) + # magic constant that describes the difference between log_prob and mse losses + img_rec = (l2_loss * normalizing_factor).sum(dim=1).mean() * self.slots_num * 8 + decoded_imgs = decoded_imgs * img_mask + else: + decoded_imgs = decoded_imgs * img_mask + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean() else: img_rec = 0 decoded_imgs_detached, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) - decoded_imgs_detached = decoded_imgs_detached * img_mask - x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) - losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() + img_mask = self.slot_mask(masks) + + if self.per_slot_rec_loss: + l2_loss = (img_mask*((decoded_imgs_detached - obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) + normalizing_factor = (torch.prod(torch.tensor(obs.shape[1:])))/img_mask.sum(dim=[2, 3, 4]) + # magic constant that describes the difference between log_prob and mse losses + img_rec_detached = (l2_loss * normalizing_factor).sum(dim=1).mean() * self.slots_num * 8 + else: + decoded_imgs_detached = decoded_imgs_detached * img_mask + x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) + img_rec_detached = -x_r_detached.log_prob(obs).float().mean() + + losses['loss_reconstruction_img'] = img_rec_detached decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, 8, 8).split([self.vit_feat_dim, 1], dim=2) - feat_mask = F.softmax(masks, dim=1) - decoded_feats = decoded_feats * feat_mask - d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) - losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 8, 8)).float().mean() + - (1-self.vit_l2_ratio) * img_rec) + feat_mask = self.slot_mask(masks) + + d_obs = d_features.reshape(b, self.vit_feat_dim, 8, 8) + + if self.per_slot_rec_loss: + l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) + normalizing_factor = (torch.prod(torch.tensor(d_obs.shape[1:])))/feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # # magic constant that describes the difference between log_prob and mse losses + d_rec = (l2_loss * normalizing_factor).sum(dim=1).mean()*self.slots_num * 4 + decoded_feats = decoded_feats * 
feat_mask + else: + decoded_feats = decoded_feats * feat_mask + d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) + d_rec = -d_pred.log_prob(d_obs).float().mean() + + losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) prior_logits = prior.stoch_logits posterior_logits = posterior.stoch_logits diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 0b9125a..45aaf76 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -134,7 +134,8 @@ def from_np(self, arr: np.ndarray): def train(self, rollout_chunks: RolloutChunks): obs, a, r, is_finished, is_first, additional = unpack(rollout_chunks) - torch.cuda.current_stream().synchronize() + if torch.cuda.is_available(): + torch.cuda.current_stream().synchronize() # obs = self.preprocess_obs(self.from_np(obs)) if self.is_discrete: a = F.one_hot(a.to(torch.int64), num_classes=self.actions_num).squeeze() diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index a1447e2..0d1ce59 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -203,7 +203,7 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ decoded_imgs, masks = self.agent.world_model.image_predictor(state.combined_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) # TODO: try the scaling of softmax as in attention - img_mask = F.softmax(masks, dim=1) + img_mask = self.agent.world_model.slot_mask(masks) decoded_imgs = decoded_imgs * img_mask video_r = torch.sum(decoded_imgs, dim=1) @@ -223,7 +223,7 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ # video_r = self.agent.world_model.image_predictor(states.combined_slots[1:]).mode decoded_imgs, masks = self.agent.world_model.image_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) + img_mask = self.agent.world_model.slot_mask(masks) decoded_imgs = decoded_imgs * img_mask video_r = torch.sum(decoded_imgs, dim=1) @@ -292,7 +292,7 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ decoded_imgs, masks = self.agent.world_model.image_predictor(state.combined_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) # TODO: try the scaling of softmax as in attention - img_mask = F.softmax(masks, dim=1) + img_mask = self.agent.world_model.slot_mask(masks) decoded_imgs = decoded_imgs * img_mask video_r = torch.sum(decoded_imgs, dim=1) @@ -319,7 +319,7 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ # video_r = self.agent.world_model.image_predictor(states.combined_slots[1:]).mode decoded_imgs, masks = self.agent.world_model.image_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) - img_mask = F.softmax(masks, dim=1) + img_mask = self.agent.world_model.slot_mask(masks) decoded_imgs = decoded_imgs * img_mask video_r = torch.sum(decoded_imgs, dim=1) From c1811499399e07cbe86f7e4ae360d70a38b25317 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 24 Jun 2023 14:41:18 +0100 Subject: [PATCH 072/106] Added corresponding configs --- rl_sandbox/agents/dreamer/world_model.py | 2 +- .../config/agent/dreamer_v2_crafter.yaml | 2 +- .../agent/dreamer_v2_crafter_slotted.yaml | 2 +- .../agent/dreamer_v2_slotted_attention.yaml | 76 ++++++++++++++++++ 
.../agent/dreamer_v2_slotted_combined.yaml | 75 ++++++++++++++++++ .../agent/dreamer_v2_slotted_debug.yaml | 77 +++++++++++++++++++ rl_sandbox/config/config_attention.yaml | 44 +++++++++++ rl_sandbox/config/config_combined.yaml | 43 +++++++++++ rl_sandbox/config/config_default.yaml | 44 +++++++++++ rl_sandbox/config/config_slotted.yaml | 17 ++-- rl_sandbox/config/config_slotted_debug.yaml | 9 +-- rl_sandbox/metrics.py | 2 +- 12 files changed, 376 insertions(+), 17 deletions(-) create mode 100644 rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml create mode 100644 rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml create mode 100644 rl_sandbox/config/agent/dreamer_v2_slotted_debug.yaml create mode 100644 rl_sandbox/config/config_attention.yaml create mode 100644 rl_sandbox/config/config_combined.yaml create mode 100644 rl_sandbox/config/config_default.yaml diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index 9bcbd93..9ba67f1 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -113,7 +113,7 @@ def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: (0.229, 0.224, 0.225)) obs = ToTensor(obs + 0.5) with torch.no_grad(): - d_features = self.dino_vit(obs.unsqueeze(0)).squeeze().cpu() + d_features = self.dino_vit(obs).cpu() return {'d_features': d_features} def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index 12e7f78..23842fa 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -1,7 +1,7 @@ _target_: rl_sandbox.agents.DreamerV2 imagination_horizon: 15 -batch_cluster_size: 20 +batch_cluster_size: 50 layer_norm: true world_model: diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml index e86ead4..32b0596 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml @@ -1,7 +1,7 @@ _target_: rl_sandbox.agents.DreamerV2 imagination_horizon: 15 -batch_cluster_size: 20 +batch_cluster_size: 50 layer_norm: true world_model: diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml new file mode 100644 index 0000000..e9124bc --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -0,0 +1,76 @@ +_target_: rl_sandbox.agents.DreamerV2 + +imagination_horizon: 15 +batch_cluster_size: 50 +layer_norm: true + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model_slots_attention.WorldModel + _partial_: true + batch_cluster_size: ${..batch_cluster_size} + latent_dim: 32 + latent_classes: ${.latent_dim} + rssm_dim: 512 + slots_num: 6 + slots_iter_num: 2 + kl_loss_scale: 1e2 + kl_loss_balancing: 0.8 + kl_free_nats: 0.00 + discrete_rssm: false + decode_vit: true + full_qk_from: 1e6 + symmetric_qk: false + use_prev_slots: false + attention_block_num: 3 + vit_l2_ratio: 0.1 + encode_vit: false + predict_discount: true + layer_norm: ${..layer_norm} + +actor: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor + _partial_: true + # mixing of reinforce and maximizing value func + # for dm_control it is zero in Dreamer (Atari 1) + reinforce_fraction: null + entropy_scale: 3e-3 + layer_norm: ${..layer_norm} + +critic: + _target_: 
rl_sandbox.agents.dreamer.ac.ImaginativeCritic + _partial_: true + discount_factor: 0.999 + update_interval: 100 + # [0-1], 1 means hard update + soft_update_fraction: 1 + # Lambda parameter for trainin deeper multi-step prediction + value_target_lambda: 0.95 + layer_norm: ${..layer_norm} + +wm_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 + lr: 1e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +actor_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 1e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +critic_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 1e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml new file mode 100644 index 0000000..f6fe79b --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml @@ -0,0 +1,75 @@ +_target_: rl_sandbox.agents.DreamerV2 + +imagination_horizon: 15 +batch_cluster_size: 50 +layer_norm: true + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model_slots_combined.WorldModel + _partial_: true + batch_cluster_size: ${..batch_cluster_size} + latent_dim: 32 + latent_classes: ${.latent_dim} + rssm_dim: 512 + slots_num: 6 + slots_iter_num: 2 + kl_loss_scale: 1e2 # Try a bit higher to enforce more predicting the future instead of reconstruction + kl_loss_balancing: 0.8 + kl_free_nats: 0.00 + discrete_rssm: false + decode_vit: true + use_prev_slots: false + vit_l2_ratio: 0.1 + encode_vit: false + mask_combination: soft + per_slot_rec_loss: true + predict_discount: true + layer_norm: ${..layer_norm} + +actor: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor + _partial_: true + # mixing of reinforce and maximizing value func + # for dm_control it is zero in Dreamer (Atari 1) + reinforce_fraction: null + entropy_scale: 3e-3 + layer_norm: ${..layer_norm} + +critic: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic + _partial_: true + discount_factor: 0.999 + update_interval: 100 + # [0-1], 1 means hard update + soft_update_fraction: 1 + # Lambda parameter for trainin deeper multi-step prediction + value_target_lambda: 0.95 + layer_norm: ${..layer_norm} + +wm_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 + lr: 1e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +actor_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 1e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +critic_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 1e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_debug.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_debug.yaml new file mode 100644 index 0000000..b205f32 --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_debug.yaml @@ -0,0 +1,77 @@ +_target_: rl_sandbox.agents.DreamerV2 + +imagination_horizon: 15 +batch_cluster_size: 50 +layer_norm: true + +world_model: + _target_: rl_sandbox.agents.dreamer.world_model_slots_attention.WorldModel + _partial_: true + batch_cluster_size: ${..batch_cluster_size} + latent_dim: 32 + latent_classes: 32 + rssm_dim: 200 + slots_num: 4 + slots_iter_num: 2 + kl_loss_scale: 1000 + 
kl_loss_balancing: 0.8 + kl_free_nats: 0.0005 + discrete_rssm: false + decode_vit: true + vit_l2_ratio: 0.75 + use_prev_slots: false + encode_vit: false + predict_discount: false + layer_norm: ${..layer_norm} + +actor: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor + _partial_: true + # mixing of reinforce and maximizing value func + # for dm_control it is zero in Dreamer (Atari 1) + reinforce_fraction: null + entropy_scale: 1e-4 + layer_norm: ${..layer_norm} + +critic: + _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic + _partial_: true + discount_factor: 0.999 + update_interval: 100 + # [0-1], 1 means hard update + soft_update_fraction: 1 + # Lambda parameter for trainin deeper multi-step prediction + value_target_lambda: 0.95 + layer_norm: ${..layer_norm} + +wm_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 + #- _target_: rl_sandbox.utils.optimizer.DecayScheduler + # _partial_: true + # decay_rate: 0.5 + # decay_steps: 5e5 + lr: 3e-4 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +actor_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 8e-5 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 + +critic_optim: + _target_: rl_sandbox.utils.optimizer.Optimizer + _partial_: true + lr: 8e-5 + eps: 1e-5 + weight_decay: 1e-6 + clip: 100 diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml new file mode 100644 index 0000000..fd23f0a --- /dev/null +++ b/rl_sandbox/config/config_attention.yaml @@ -0,0 +1,44 @@ +defaults: + - agent: dreamer_v2_slotted_attention + - env: crafter + - training: crafter + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +logger: + type: tensorboard + message: Crafter attention slotted + log_grads: false + +training: + checkpoint_path: null + steps: 1e6 + val_logs_every: 2e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + n_jobs: 1 + #n_jobs: 8 + #sweeper: + # params: + # agent.world_model.full_qk_from: 1,1e6 + # agent.world_model.symmetric_qk: true,false + # agent.world_model.attention_block_num: 1,3 diff --git a/rl_sandbox/config/config_combined.yaml b/rl_sandbox/config/config_combined.yaml new file mode 100644 index 0000000..3b236dc --- /dev/null +++ b/rl_sandbox/config/config_combined.yaml @@ -0,0 +1,43 @@ +defaults: + - agent: dreamer_v2_slotted_combined + - env: crafter + - training: crafter + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +logger: + type: tensorboard + message: Crafter combined slotted + log_grads: false + +training: + checkpoint_path: null + steps: 1e6 + val_logs_every: 1e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + n_jobs: 1 + #sweeper: + # params: + # agent.world_model.per_slot_rec_loss: false,true + # agent.world_model.mask_combination: soft,hard + # agent.world_model.kl_loss_scale: 1e2 diff --git 
a/rl_sandbox/config/config_default.yaml b/rl_sandbox/config/config_default.yaml new file mode 100644 index 0000000..c16ee48 --- /dev/null +++ b/rl_sandbox/config/config_default.yaml @@ -0,0 +1,44 @@ +defaults: + - agent: dreamer_v2_crafter + - env: crafter + - training: crafter + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +logger: + type: tensorboard + message: Crafter default + log_grads: false + +training: + checkpoint_path: null + steps: 1e6 + val_logs_every: 5e2 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + n_jobs: 1 + #sweeper: + # params: + # agent.world_model._target_: rl_sandbox.agents.dreamer.world_model_slots_combined.WorldModel,rl_sandbox.agents.dreamer.world_model_slots_attention.WorldModel + # agent.world_model.vit_l2_ratio: 0.1,0.5 + # agent.world_model.kl_loss_scale: 1e1,1e2,1e3,1e4 + # agent.world_model.vit_l2_ratio: 0.1,0.9 diff --git a/rl_sandbox/config/config_slotted.yaml b/rl_sandbox/config/config_slotted.yaml index 0ec9369..0a218e4 100644 --- a/rl_sandbox/config/config_slotted.yaml +++ b/rl_sandbox/config/config_slotted.yaml @@ -32,12 +32,13 @@ debug: profiler: false hydra: - mode: MULTIRUN - #mode: RUN + #mode: MULTIRUN + mode: RUN launcher: - n_jobs: 8 - #n_jobs: 1 - sweeper: - params: - agent.world_model.kl_loss_scale: 0.1,1e2,1e3,1e4 - agent.world_model.kl_free_nats: 0,1e-2 + n_jobs: 1 + #sweeper: + # params: + # agent.world_model.kl_loss_scale: 1e1,1e2,1e3,1e4 + # agent.world_model.vit_l2_ratio: 0.1,0.9 + + diff --git a/rl_sandbox/config/config_slotted_debug.yaml b/rl_sandbox/config/config_slotted_debug.yaml index 96f3d87..e717bf3 100644 --- a/rl_sandbox/config/config_slotted_debug.yaml +++ b/rl_sandbox/config/config_slotted_debug.yaml @@ -25,7 +25,7 @@ validation: - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator log_video: True _partial_: true - - _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator + - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator _partial_: true debug: @@ -35,10 +35,9 @@ hydra: #mode: MULTIRUN mode: RUN launcher: - #n_jobs: 4 n_jobs: 1 #sweeper: # params: - # agent.world_model.kl_loss_scale: 1e3 - #agent.world_model.latent_dim: 22,32 - #agent.world_model.rssm_dim: 128,256 + # agent.world_model.kl_loss_scale: 1e1,1e2,1e3,1e4 + # agent.world_model.vit_l2_ratio: 0.1,0.9 + diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 0d1ce59..e946386 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -168,7 +168,7 @@ def on_episode(self, logger): mu_hist = torch.mean((mu - mu.squeeze(0).unsqueeze(1)) ** 2, dim=-1) sigma_hist = torch.mean((sigma - sigma.squeeze(0).unsqueeze(1)) ** 2, dim=-1) - if wm.recurrent_model.last_attention is not None: + if hasattr(wm.recurrent_model, 'last_attention'): logger.add_image('val/mixer_attention', wm.recurrent_model.last_attention, self.episode, dataformats='HW') logger.add_image('val/slot_attention_mu', mu_hist/mu_hist.max(), self.episode, dataformats='HW') From cc362aeff0e7cfa16e353e91247f66c34dc50a5f Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 24 Jun 2023 21:30:00 +0100 Subject: [PATCH 073/106] Rewritten hard mixing and simplified per slot reconstruction loss --- .../agents/dreamer/world_model_slots.py | 3 +- .../dreamer/world_model_slots_attention.py | 3 +- 
.../dreamer/world_model_slots_combined.py | 42 ++++++++++--------- .../agent/dreamer_v2_slotted_attention.yaml | 4 +- .../agent/dreamer_v2_slotted_combined.yaml | 2 +- rl_sandbox/config/config.yaml | 17 ++++---- rl_sandbox/config/config_attention.yaml | 2 +- rl_sandbox/config/config_combined.yaml | 23 +++++----- 8 files changed, 52 insertions(+), 44 deletions(-) diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index c6ebe0b..79cd5d3 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -133,7 +133,8 @@ def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: case 'soft': img_mask = F.softmax(masks, dim=1) case 'hard': - img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + probs = F.softmax(masks - masks.logsumexp(dim=1,keepdim=True), dim=1) + img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + (probs - probs.detach()) case 'qmix': raise NotImplementedError case _: diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index b8178dc..91c4a1e 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -141,7 +141,8 @@ def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: case 'soft': img_mask = F.softmax(masks, dim=1) case 'hard': - img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + probs = F.softmax(masks - masks.logsumexp(dim=1,keepdim=True), dim=1) + img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + (probs - probs.detach()) case 'qmix': raise NotImplementedError case _: diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index 875d45b..c135dc7 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -1,5 +1,6 @@ import typing as t +import math import torch import torch.distributions as td import torchvision as tv @@ -136,7 +137,8 @@ def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: case 'soft': img_mask = F.softmax(masks, dim=1) case 'hard': - img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + probs = F.softmax(masks - masks.logsumexp(dim=1,keepdim=True), dim=1) + img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + (probs - probs.detach()) case 'qmix': raise NotImplementedError case _: @@ -284,15 +286,14 @@ def KL(dist1, dist2): if not self.decode_vit: decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) img_mask = self.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask if self.per_slot_rec_loss: - l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) - normalizing_factor = (torch.prod(torch.tensor(obs.shape[1:]))) / img_mask.sum(dim=[2, 3, 4]) + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) # magic constant that describes the difference between log_prob and mse losses - img_rec = (l2_loss * normalizing_factor).sum(dim=1).mean() * 
self.slots_num * 8 - decoded_imgs = decoded_imgs * img_mask + img_rec = l2_loss.mean() * normalizing_factor * 8 else: - decoded_imgs = decoded_imgs * img_mask x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean() @@ -301,29 +302,28 @@ def KL(dist1, dist2): if self.vit_l2_ratio != 1.0: decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) img_mask = self.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask if self.per_slot_rec_loss: - l2_loss = (img_mask*((decoded_imgs - obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) - normalizing_factor = (torch.prod(torch.tensor(obs.shape[1:])))/img_mask.sum(dim=[2, 3, 4]) + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) # magic constant that describes the difference between log_prob and mse losses - img_rec = (l2_loss * normalizing_factor).sum(dim=1).mean() * self.slots_num * 8 - decoded_imgs = decoded_imgs * img_mask + img_rec = l2_loss.mean() * normalizing_factor * 8 else: - decoded_imgs = decoded_imgs * img_mask x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean() else: img_rec = 0 decoded_imgs_detached, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=2) img_mask = self.slot_mask(masks) + decoded_imgs_detached = decoded_imgs_detached * img_mask if self.per_slot_rec_loss: - l2_loss = (img_mask*((decoded_imgs_detached - obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) - normalizing_factor = (torch.prod(torch.tensor(obs.shape[1:])))/img_mask.sum(dim=[2, 3, 4]) + l2_loss = (img_mask * ((decoded_imgs_detached - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) # magic constant that describes the difference between log_prob and mse losses - img_rec_detached = (l2_loss * normalizing_factor).sum(dim=1).mean() * self.slots_num * 8 + img_rec_detached = l2_loss.mean() * normalizing_factor * 8 else: - decoded_imgs_detached = decoded_imgs_detached * img_mask x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) img_rec_detached = -x_r_detached.log_prob(obs).float().mean() @@ -334,18 +334,20 @@ def KL(dist1, dist2): d_obs = d_features.reshape(b, self.vit_feat_dim, 8, 8) + decoded_feats = decoded_feats * feat_mask if self.per_slot_rec_loss: - l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) - normalizing_factor = (torch.prod(torch.tensor(d_obs.shape[1:])))/feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) + l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(d_obs.shape)[-3:]) / feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=2).max(dim=2).values.max(dim=2).values * (64*64*3) # # magic constant that describes the difference between log_prob and mse losses - d_rec = (l2_loss * normalizing_factor).sum(dim=1).mean()*self.slots_num * 4 - decoded_feats = decoded_feats * feat_mask + d_rec = l2_loss.mean() * normalizing_factor * 4 else: - decoded_feats = decoded_feats * feat_mask d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, 
dim=1), 1.0), 3) d_rec = -d_pred.log_prob(d_obs).float().mean() losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) + metrics['loss_l2_rec'] = img_rec + metrics['loss_dino_rec'] = d_rec prior_logits = prior.stoch_logits posterior_logits = posterior.stoch_logits diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml index e9124bc..f0e832d 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -19,9 +19,9 @@ world_model: discrete_rssm: false decode_vit: true full_qk_from: 1e6 - symmetric_qk: false + symmetric_qk: true use_prev_slots: false - attention_block_num: 3 + attention_block_num: 1 vit_l2_ratio: 0.1 encode_vit: false predict_discount: true diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml index f6fe79b..68b9ea1 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml @@ -1,7 +1,7 @@ _target_: rl_sandbox.agents.DreamerV2 imagination_horizon: 15 -batch_cluster_size: 50 +batch_cluster_size: 20 layer_norm: true world_model: diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 859e9ee..0e89318 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,5 +1,5 @@ defaults: - - agent: dreamer_v2_crafter + - agent: dreamer_v2_slotted_combined - env: crafter - training: crafter - _self_ @@ -10,22 +10,22 @@ device_type: cuda logger: type: tensorboard - message: Crafter default + message: Crafter combined slotted log_grads: false training: checkpoint_path: null steps: 1e6 - val_logs_every: 5e2 + val_logs_every: 1e4 validation: - rollout_num: 5 + rollout_num: 3 visualize: true metrics: - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator log_video: True _partial_: true - - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator + - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator _partial_: true debug: @@ -38,5 +38,8 @@ hydra: n_jobs: 8 sweeper: params: - agent.world_model.kl_loss_scale: 1e1,1e2,1e3,1e4 - agent.world_model.vit_l2_ratio: 0.1,0.9 + agent.world_model.per_slot_rec_loss: false,true + agent.world_model.mask_combination: soft,hard + agent.world_model.kl_loss_scale: 1e3 + agent.world_model.vit_l2_ratio: 0.9 + diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index fd23f0a..5df0bdb 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -36,7 +36,7 @@ hydra: mode: RUN launcher: n_jobs: 1 - #n_jobs: 8 + # n_jobs: 8 #sweeper: # params: # agent.world_model.full_qk_from: 1,1e6 diff --git a/rl_sandbox/config/config_combined.yaml b/rl_sandbox/config/config_combined.yaml index 3b236dc..b16cd88 100644 --- a/rl_sandbox/config/config_combined.yaml +++ b/rl_sandbox/config/config_combined.yaml @@ -10,16 +10,16 @@ device_type: cuda logger: type: tensorboard - message: Crafter combined slotted + message: Crafter combined slotted 10x reconstruction log_grads: false training: checkpoint_path: null steps: 1e6 - val_logs_every: 1e4 + val_logs_every: 5e3 validation: - rollout_num: 5 + rollout_num: 3 visualize: true metrics: - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator @@ -32,12 +32,13 @@ debug: profiler: false hydra: - #mode: MULTIRUN - mode: RUN + mode: MULTIRUN + #mode: RUN launcher: - n_jobs: 1 
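Note on the per-slot reconstruction change above: the loss is now a masked sum over channels and pixels, rescaled by how many pixels each slot mask actually covers (clamped to avoid division by zero), instead of a plain per-slot mean. Below is a minimal sketch of that computation, assuming slot reconstructions of shape (B, S, 3, H, W), soft masks of shape (B, S, 1, H, W) that sum to one over slots, and a scalar reduction at the end; the factor of 8 is the empirical scale constant kept from the diff.

import torch

def per_slot_rec_loss(decoded_imgs: torch.Tensor,
                      img_mask: torch.Tensor,
                      obs: torch.Tensor,
                      scale: float = 8.0) -> torch.Tensor:
    # decoded_imgs: (B, S, 3, H, W) slot-wise reconstructions
    # img_mask:     (B, S, 1, H, W) soft masks, summing to 1 over S
    # obs:          (B, 3, H, W) target image
    l2 = (img_mask * (decoded_imgs - obs.unsqueeze(1)) ** 2).sum(dim=[2, 3, 4])
    # Weight each slot by total pixels / pixels it owns, so slots that
    # explain only a small region are not drowned out; clamp avoids div by 0.
    total = torch.prod(torch.tensor(obs.shape[-3:]))
    norm = total / img_mask.sum(dim=[2, 3, 4]).clamp(min=1)
    return (l2 * norm).mean() * scale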
- #sweeper: - # params: - # agent.world_model.per_slot_rec_loss: false,true - # agent.world_model.mask_combination: soft,hard - # agent.world_model.kl_loss_scale: 1e2 + n_jobs: 8 + sweeper: + params: + agent.world_model.per_slot_rec_loss: true + agent.world_model.mask_combination: soft + agent.world_model.kl_loss_scale: 1e2,1e4 + agent.world_model.vit_l2_ratio: 0.1,0.01,0.9,0.99 From 5fcbc562d0f3b59df0e13058395e31518a497066 Mon Sep 17 00:00:00 2001 From: Midren Date: Mon, 3 Jul 2023 21:15:06 +0100 Subject: [PATCH 074/106] Performance improvements --- rl_sandbox/agents/dreamer/ac.py | 11 +++++++---- rl_sandbox/agents/dreamer_v2.py | 11 ++++++----- rl_sandbox/config/agent/dreamer_v2_crafter.yaml | 2 +- rl_sandbox/config/config_default.yaml | 2 +- rl_sandbox/config/training/crafter.yaml | 1 + rl_sandbox/config/training/dm.yaml | 1 + rl_sandbox/train.py | 1 + rl_sandbox/utils/rollout_generation.py | 6 ++++-- 8 files changed, 22 insertions(+), 13 deletions(-) diff --git a/rl_sandbox/agents/dreamer/ac.py b/rl_sandbox/agents/dreamer/ac.py index d774fda..ca8a4b8 100644 --- a/rl_sandbox/agents/dreamer/ac.py +++ b/rl_sandbox/agents/dreamer/ac.py @@ -58,8 +58,8 @@ def _lambda_return(self, vs: torch.Tensor, rs: torch.Tensor, ds: torch.Tensor): (1 - self.lambda_) * vs[i + 1] + self.lambda_ * v_lambdas[-1]) v_lambdas.append(v_lambda) - # FIXME: it copies array, so it is quite slow - return torch.stack(v_lambdas).flip(dims=(0, ))[:-1] + reversed_indices = torch.arange(len(v_lambdas)-1, -1, -1) + return torch.stack(v_lambdas)[reversed_indices][:-1] def lambda_return(self, zs, rs, ds): vs = self.target_critic(zs).mode @@ -119,8 +119,11 @@ def calculate_loss(self, zs: torch.Tensor, vs: torch.Tensor, baseline: torch.Ten advantage = (vs - baseline).detach() losses['loss_actor_reinforce'] = -(self.rho * action_dists.log_prob( actions.detach()).unsqueeze(2) * discount_factors * advantage).mean() - losses['loss_actor_dynamics_backprop'] = -((1 - self.rho) * - (vs * discount_factors)).mean() + if self.rho != 1.0: + losses['loss_actor_dynamics_backprop'] = -((1 - self.rho) * + (vs * discount_factors)).mean() + else: + losses['loss_actor_dynamics_backprop'] = torch.tensor(0) def calculate_entropy(dist): # return dist.base_dist.base_dist.entropy().unsqueeze(2) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 45aaf76..42d407b 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -31,6 +31,7 @@ def __init__( critic_optim: t.Any, layer_norm: bool, batch_cluster_size: int, + f16_precision: bool, device_type: str = 'cpu', logger = None): @@ -39,6 +40,7 @@ def __init__( self.imagination_horizon = imagination_horizon self.actions_num = actions_num self.is_discrete = (action_type != 'continuous') + self.is_f16 = f16_precision self.world_model: WorldModel = world_model(actions_num=actions_num).to(device_type) self.actor: ImaginativeActor = actor(latent_dim=self.world_model.state_size, @@ -46,8 +48,8 @@ def __init__( is_discrete=self.is_discrete).to(device_type) self.critic: ImaginativeCritic = critic(latent_dim=self.world_model.state_size).to(device_type) - self.world_model_optimizer = wm_optim(model=self.world_model) - self.image_predictor_optimizer = wm_optim(model=self.world_model.image_predictor) + self.world_model_optimizer = wm_optim(model=self.world_model, scaler=self.is_f16) + self.image_predictor_optimizer = wm_optim(model=self.world_model.image_predictor, scaler=self.is_f16) self.actor_optimizer = actor_optim(model=self.actor) self.critic_optimizer 
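Note on the _lambda_return change in this patch: the backwards-accumulated list of lambda-returns is now reversed by indexing with a reversed arange rather than tensor.flip, and the dynamics-backprop term is skipped entirely when rho == 1.0 so no graph is built for a zero-weighted loss. A small self-contained check that the two reversal forms agree, using dummy values:

import torch

v_lambdas = [torch.randn(4) for _ in range(16)]  # accumulated from the last step backwards
stacked = torch.stack(v_lambdas)

old = stacked.flip(dims=(0,))[:-1]                        # previous formulation
reversed_indices = torch.arange(len(v_lambdas) - 1, -1, -1)
new = stacked[reversed_indices][:-1]                      # reversed advanced indexing

assert torch.equal(old, new)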
= critic_optim(model=self.critic) @@ -143,7 +145,7 @@ def train(self, rollout_chunks: RolloutChunks): first_flags = is_first.float() # take some latent embeddings as initial - with torch.cuda.amp.autocast(enabled=False): + with torch.cuda.amp.autocast(enabled=self.is_f16): losses_wm, discovered_states, metrics_wm = self.world_model.calculate_loss(obs, a, r, discount_factors, first_flags, additional) # FIXME: wholely remove discrete RSSM # self.world_model.recurrent_model.discretizer_scheduler.step() @@ -154,8 +156,7 @@ def train(self, rollout_chunks: RolloutChunks): metrics_wm |= self.world_model_optimizer.step(losses_wm['loss_wm']) - with torch.cuda.amp.autocast(enabled=False): - losses_ac = {} + with torch.cuda.amp.autocast(enabled=self.is_f16): initial_states = discovered_states.__class__(discovered_states.determ.flatten(0, 1).unsqueeze(0).detach(), discovered_states.stoch_logits.flatten(0, 1).unsqueeze(0).detach(), discovered_states.stoch_.flatten(0, 1).unsqueeze(0).detach()) diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index 23842fa..26a8c91 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -15,7 +15,7 @@ world_model: kl_loss_balancing: 0.8 kl_free_nats: 0.00 discrete_rssm: false - decode_vit: true + decode_vit: false vit_l2_ratio: 1.0 encode_vit: false predict_discount: true diff --git a/rl_sandbox/config/config_default.yaml b/rl_sandbox/config/config_default.yaml index c16ee48..910d50d 100644 --- a/rl_sandbox/config/config_default.yaml +++ b/rl_sandbox/config/config_default.yaml @@ -16,7 +16,7 @@ logger: training: checkpoint_path: null steps: 1e6 - val_logs_every: 5e2 + val_logs_every: 2e4 validation: rollout_num: 5 diff --git a/rl_sandbox/config/training/crafter.yaml b/rl_sandbox/config/training/crafter.yaml index 2d23c4b..ba1943f 100644 --- a/rl_sandbox/config/training/crafter.yaml +++ b/rl_sandbox/config/training/crafter.yaml @@ -1,6 +1,7 @@ steps: 1e6 prefill: 10000 batch_size: 16 +f16_precision: false pretrain: 1 prioritize_ends: true train_every: 5 diff --git a/rl_sandbox/config/training/dm.yaml b/rl_sandbox/config/training/dm.yaml index fd32015..a4328ba 100644 --- a/rl_sandbox/config/training/dm.yaml +++ b/rl_sandbox/config/training/dm.yaml @@ -6,3 +6,4 @@ prioritize_ends: false train_every: 5 save_checkpoint_every: 2e6 val_logs_every: 2e4 +f16_precision: false diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index f7f6e61..74a8ffd 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -72,6 +72,7 @@ def main(cfg: DictConfig): actions_num=env.action_space.n if is_discrete else env.action_space.shape[0], action_type='discrete' if is_discrete else 'continuous', device_type=cfg.device_type, + f16_precision=cfg.training.f16_precision, logger=logger) buff = ReplayBuffer(prioritize_ends=cfg.training.prioritize_ends, diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index 0d2b08d..7ae5423 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -56,7 +56,8 @@ def iter_rollout(env: Env, reward = 0.0 is_first = True - action = torch.zeros_like(agent.get_action(state)) + with torch.no_grad(): + action = torch.zeros_like(agent.get_action(state)) while not terminated: try: @@ -74,7 +75,8 @@ def iter_rollout(env: Env, is_first=is_first) is_first = False - action = agent.get_action(state) + with torch.no_grad(): + action = agent.get_action(state) state, reward, 
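Note on the f16_precision flag introduced here: it enables torch.cuda.amp autocast around the world-model losses, and the optimizer wrapper is handed a scaler flag so it can run loss scaling. A generic sketch of that autocast + GradScaler pattern, using a stand-in linear model and plain Adam rather than the repo's Optimizer wrapper:

import torch
from torch import nn

use_f16 = True
model = nn.Linear(64, 1).cuda()
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=use_f16)

x = torch.randn(16, 64, device='cuda')
with torch.cuda.amp.autocast(enabled=use_f16):
    loss = model(x).pow(2).mean()

# Scale the loss so small fp16 gradients do not underflow; step() unscales
# the gradients and skips the update if any of them overflowed.
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
optim.zero_grad(set_to_none=True)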
terminated = unpack(env.step(action)) From 42f04c9eebf677371bf83bc0404fc89a995c2c75 Mon Sep 17 00:00:00 2001 From: Midren Date: Mon, 3 Jul 2023 21:18:39 +0100 Subject: [PATCH 075/106] Fixed that layer normalization was not applying --- rl_sandbox/agents/dreamer/world_model.py | 8 ++++---- rl_sandbox/agents/dreamer/world_model_slots.py | 8 ++++---- rl_sandbox/agents/dreamer/world_model_slots_attention.py | 8 ++++---- rl_sandbox/agents/dreamer/world_model_slots_combined.py | 8 ++++---- rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml | 2 +- rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml | 4 ++-- rl_sandbox/utils/fc_nn.py | 2 +- 7 files changed, 20 insertions(+), 20 deletions(-) diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index 9ba67f1..718ac35 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -42,7 +42,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, latent_classes, discrete_rssm, - norm_layer=nn.Identity if layer_norm else nn.LayerNorm) + norm_layer=nn.LayerNorm if layer_norm else nn.Identity) if encode_vit or decode_vit: # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) @@ -63,7 +63,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, # layer_norm=layer_norm) ) else: - self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4, 4, 4], channel_step=48, double_conv=False) @@ -71,7 +71,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(self.state_size, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, channel_step=192, kernel_sizes=[3, 4], output_channels=self.vit_feat_dim, @@ -84,7 +84,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, # layer_norm=layer_norm, # final_activation=DistLayer('mse')) self.image_predictor = Decoder(self.state_size, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm) + norm_layer=nn.GroupNorm if layer_norm else nn.Identity) self.reward_predictor = fc_nn_generator(self.state_size, 1, diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index 79cd5d3..9636db3 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -51,7 +51,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, latent_classes, discrete_rssm, - norm_layer=nn.Identity if layer_norm else nn.LayerNorm, + norm_layer=nn.LayerNorm if layer_norm else nn.Identity, embed_size=self.n_dim) if encode_vit or decode_vit: # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) @@ -77,7 +77,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, # layer_norm=layer_norm) ) else: - self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4, 4], 
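Note on the bug fixed in this patch: the previous expression nn.Identity if layer_norm else nn.LayerNorm handed the wrong class to the constructor, so requesting layer_norm: true actually disabled normalization everywhere. A tiny sketch of the corrected selection, using a toy block factory rather than one of the repo's modules:

from torch import nn

def make_block(in_dim: int, out_dim: int, layer_norm: bool) -> nn.Sequential:
    # Choose the class first, instantiate it below; nn.Identity ignores its args.
    norm_layer = nn.LayerNorm if layer_norm else nn.Identity
    return nn.Sequential(nn.Linear(in_dim, out_dim), norm_layer(out_dim), nn.ELU())

print(make_block(32, 64, layer_norm=True))   # contains LayerNorm((64,), ...)
print(make_block(32, 64, layer_norm=False))  # contains Identity()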
channel_step=96, double_conv=True, @@ -93,7 +93,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, channel_step=192, # kernel_sizes=[5, 5, 4], # for size 224x224 kernel_sizes=[3, 4], @@ -108,7 +108,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, # final_activation=DistLayer('mse')) self.image_predictor = Decoder( rssm_dim + latent_dim * latent_classes, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, output_channels=3+1, return_dist=False) diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index 91c4a1e..cc62fbf 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -56,7 +56,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, latent_classes, discrete_rssm, - norm_layer=nn.Identity if layer_norm else nn.LayerNorm, + norm_layer=nn.LayerNorm if layer_norm else nn.Identity, embed_size=self.n_dim, full_qk_from=full_qk_from, symmetric_qk=symmetric_qk, @@ -85,7 +85,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, # layer_norm=layer_norm) ) else: - self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4, 4], channel_step=96, double_conv=True, @@ -101,7 +101,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, channel_step=192, # kernel_sizes=[5, 5, 4], # for size 224x224 kernel_sizes=[3, 4], @@ -116,7 +116,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, # final_activation=DistLayer('mse')) self.image_predictor = Decoder( rssm_dim + latent_dim * latent_classes, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, output_channels=3+1, return_dist=False) diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index c135dc7..900bf7f 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -54,7 +54,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, latent_classes, discrete_rssm, - norm_layer=nn.Identity if layer_norm else nn.LayerNorm, + norm_layer=nn.LayerNorm if layer_norm else nn.Identity, slots_num=slots_num, embed_size=self.n_dim) if encode_vit or decode_vit: @@ -81,7 +81,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, # layer_norm=layer_norm) ) else: - self.encoder = Encoder(norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4, 4], channel_step=96, double_conv=True, @@ -97,7 +97,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: 
self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, channel_step=192, # kernel_sizes=[5, 5, 4], # for size 224x224 kernel_sizes=[3, 4], @@ -112,7 +112,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, # final_activation=DistLayer('mse')) self.image_predictor = Decoder( rssm_dim + latent_dim * latent_classes, - norm_layer=nn.Identity if layer_norm else nn.GroupNorm, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, output_channels=3+1, return_dist=False) diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml index f0e832d..ea0cbd4 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -18,7 +18,7 @@ world_model: kl_free_nats: 0.00 discrete_rssm: false decode_vit: true - full_qk_from: 1e6 + full_qk_from: 4e4 symmetric_qk: true use_prev_slots: false attention_block_num: 1 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml index 68b9ea1..582ce33 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml @@ -1,7 +1,7 @@ _target_: rl_sandbox.agents.DreamerV2 imagination_horizon: 15 -batch_cluster_size: 20 +batch_cluster_size: 50 layer_norm: true world_model: @@ -22,7 +22,7 @@ world_model: vit_l2_ratio: 0.1 encode_vit: false mask_combination: soft - per_slot_rec_loss: true + per_slot_rec_loss: false predict_discount: true layer_norm: ${..layer_norm} diff --git a/rl_sandbox/utils/fc_nn.py b/rl_sandbox/utils/fc_nn.py index c20f993..ebf6320 100644 --- a/rl_sandbox/utils/fc_nn.py +++ b/rl_sandbox/utils/fc_nn.py @@ -8,7 +8,7 @@ def fc_nn_generator(input_num: int, intermediate_activation: t.Type[nn.Module] = nn.ReLU, final_activation: nn.Module = nn.Identity(), layer_norm: bool = False): - norm_layer = nn.Identity if layer_norm else nn.LayerNorm + norm_layer = nn.LayerNorm if layer_norm else nn.Identity assert num_layers >= 3 layers = [] layers.append(nn.Linear(input_num, hidden_size)) From ef2813087fa58bb03104a7c03d9ce8b58ddab336 Mon Sep 17 00:00:00 2001 From: Midren Date: Mon, 3 Jul 2023 22:02:20 +0100 Subject: [PATCH 076/106] Moved the reward transformation inside dreamer for correct tensorboard logging --- rl_sandbox/agents/dreamer_v2.py | 11 ++++++++++- rl_sandbox/config/agent/dreamer_v2.yaml | 1 + rl_sandbox/config/agent/dreamer_v2_crafter.yaml | 3 ++- .../config/agent/dreamer_v2_crafter_slotted.yaml | 1 + .../config/agent/dreamer_v2_slotted_attention.yaml | 1 + .../config/agent/dreamer_v2_slotted_combined.yaml | 1 + rl_sandbox/config/agent/dreamer_v2_slotted_debug.yaml | 1 + rl_sandbox/train.py | 10 +++++----- rl_sandbox/utils/env.py | 3 +-- 9 files changed, 23 insertions(+), 9 deletions(-) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 42d407b..6d9e407 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -3,6 +3,7 @@ import numpy as np import torch +from torch import nn from torch.nn import functional as F import torchvision as tv from unpackable import unpack @@ -20,6 +21,7 @@ class DreamerV2(RlAgent): def __init__( self, obs_space_num: list[int], # NOTE: encoder/decoder will work only with 64x64 currently + clip_rewards: str, actions_num: int, 
world_model: t.Any, actor: t.Any, @@ -40,6 +42,13 @@ def __init__( self.imagination_horizon = imagination_horizon self.actions_num = actions_num self.is_discrete = (action_type != 'continuous') + match clip_rewards: + case 'identity': + self.reward_clipper = nn.Identity() + case 'tanh': + self.reward_clipper = nn.Tanh() + case _: + raise RuntimeError('Invalid reward clipping') self.is_f16 = f16_precision self.world_model: WorldModel = world_model(actions_num=actions_num).to(device_type) @@ -95,7 +104,7 @@ def preprocess(self, rollout: Rollout): additional = self.world_model.precalc_data(obs.to(self.device)) return Rollout(obs=obs, actions=rollout.actions, - rewards=rollout.rewards, + rewards=self.reward_clipper(rollout.rewards), is_finished=rollout.is_finished, is_first=rollout.is_first, additional_data=rollout.additional_data | additional) diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index ee3cfc6..ad707b2 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -1,5 +1,6 @@ _target_: rl_sandbox.agents.DreamerV2 +clip_rewards: identity imagination_horizon: 15 batch_cluster_size: 50 layer_norm: false diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index 26a8c91..fc164c8 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -1,5 +1,6 @@ _target_: rl_sandbox.agents.DreamerV2 +clip_rewards: tanh imagination_horizon: 15 batch_cluster_size: 50 layer_norm: true @@ -9,7 +10,7 @@ world_model: _partial_: true batch_cluster_size: ${..batch_cluster_size} latent_dim: 32 - latent_classes: 32 + latent_classes: ${.latent_dim} rssm_dim: 1024 kl_loss_scale: 1.0 kl_loss_balancing: 0.8 diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml index 32b0596..ca46343 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml @@ -1,5 +1,6 @@ _target_: rl_sandbox.agents.DreamerV2 +clip_rewards: tanh imagination_horizon: 15 batch_cluster_size: 50 layer_norm: true diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml index ea0cbd4..dbc7e4e 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -1,5 +1,6 @@ _target_: rl_sandbox.agents.DreamerV2 +clip_rewards: tanh imagination_horizon: 15 batch_cluster_size: 50 layer_norm: true diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml index 582ce33..2240c80 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml @@ -1,5 +1,6 @@ _target_: rl_sandbox.agents.DreamerV2 +clip_rewards: tanh imagination_horizon: 15 batch_cluster_size: 50 layer_norm: true diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_debug.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_debug.yaml index b205f32..bdce5e9 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_debug.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_debug.yaml @@ -1,5 +1,6 @@ _target_: rl_sandbox.agents.DreamerV2 +clip_rewards: tanh imagination_horizon: 15 batch_cluster_size: 50 layer_norm: true diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py 
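Note on clip_rewards: moving the transformation out of the environment wrapper means the replay buffer and the episode metrics keep raw rewards, and only the targets the agent trains on are squashed. A minimal sketch of the selection logic above, applied to a dummy reward tensor:

import torch
from torch import nn

def make_reward_clipper(clip_rewards: str) -> nn.Module:
    match clip_rewards:
        case 'identity':
            return nn.Identity()
        case 'tanh':
            return nn.Tanh()
        case _:
            raise RuntimeError('Invalid reward clipping')

rewards = torch.tensor([0.0, 1.0, 5.0, -3.0])
print(make_reward_clipper('tanh')(rewards))      # squashed into (-1, 1)
print(make_reward_clipper('identity')(rewards))  # unchanged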
index 74a8ffd..475f8d8 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -59,11 +59,11 @@ def main(cfg: DictConfig): val_env: Env = hydra.utils.instantiate(cfg.env) # TOOD: Create maybe some additional validation env if cfg.env.task_name.startswith("Crafter"): - val_env.env = crafter.Recorder(val_env.env, - logger.log_dir(), - save_stats=True, - save_video=False, - save_episode=False) + env.env = crafter.Recorder(env.env, + logger.log_dir(), + save_stats=True, + save_video=False, + save_episode=False) is_discrete = isinstance(env.action_space, Discrete) agent = hydra.utils.instantiate( diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py index 515dabb..39b1eda 100644 --- a/rl_sandbox/utils/env.py +++ b/rl_sandbox/utils/env.py @@ -160,8 +160,7 @@ def _step(self, action: Action, repeat_num: int) -> EnvStepResult: env_res = EnvStepResult(new_state, reward, terminated) else: env_res = ts - # FIXME: move to config the option - env_res.reward = np.tanh(rew + (env_res.reward or 0.0)) + env_res.reward = rew + (env_res.reward or 0.0) return env_res def reset(self): From 60fa7c4da9bd416cbcd8942744d648a6b52480f3 Mon Sep 17 00:00:00 2001 From: Midren Date: Tue, 4 Jul 2023 12:13:03 +0100 Subject: [PATCH 077/106] Added wandb integration --- pyproject.toml | 2 + .../agent/dreamer_v2_slotted_attention.yaml | 2 +- .../agent/dreamer_v2_slotted_combined.yaml | 2 +- rl_sandbox/config/config.yaml | 10 ++-- rl_sandbox/config/config_attention.yaml | 4 +- rl_sandbox/config/config_combined.yaml | 19 ++---- rl_sandbox/config/config_default.yaml | 4 +- rl_sandbox/metrics.py | 60 ++++++++++++------- rl_sandbox/train.py | 13 ++-- rl_sandbox/utils/logger.py | 38 +++++++++++- 10 files changed, 103 insertions(+), 51 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 256f066..bd601c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,8 @@ jaxtyping = '^0.2.0' lovely_tensors = '^0.1.10' torchshow = '^0.5.0' crafter = '^1.8.0' +wandb = '*' +flatten-dict = '*' hydra-joblib-launcher = "*" [tool.yapf] diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml index dbc7e4e..2804e80 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -11,7 +11,7 @@ world_model: batch_cluster_size: ${..batch_cluster_size} latent_dim: 32 latent_classes: ${.latent_dim} - rssm_dim: 512 + rssm_dim: 1024 slots_num: 6 slots_iter_num: 2 kl_loss_scale: 1e2 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml index 2240c80..80920d3 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml @@ -11,7 +11,7 @@ world_model: batch_cluster_size: ${..batch_cluster_size} latent_dim: 32 latent_classes: ${.latent_dim} - rssm_dim: 512 + rssm_dim: 1024 slots_num: 6 slots_iter_num: 2 kl_loss_scale: 1e2 # Try a bit higher to enforce more predicting the future instead of reconstruction diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 0e89318..30b64d9 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -2,6 +2,7 @@ defaults: - agent: dreamer_v2_slotted_combined - env: crafter - training: crafter + - logger: wandb - _self_ - override hydra/launcher: joblib @@ -9,23 +10,22 @@ seed: 42 device_type: cuda logger: - type: tensorboard - message: Crafter combined 
slotted + message: Crafter default log_grads: false training: checkpoint_path: null steps: 1e6 - val_logs_every: 1e4 + val_logs_every: 2e4 validation: - rollout_num: 3 + rollout_num: 5 visualize: true metrics: - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator log_video: True _partial_: true - - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator _partial_: true debug: diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index 5df0bdb..e105576 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -2,6 +2,7 @@ defaults: - agent: dreamer_v2_slotted_attention - env: crafter - training: crafter + - logger: wandb - _self_ - override hydra/launcher: joblib @@ -9,8 +10,7 @@ seed: 42 device_type: cuda logger: - type: tensorboard - message: Crafter attention slotted + message: Crafter attention 2x sched, fixed layer, dis f16 log_grads: false training: diff --git a/rl_sandbox/config/config_combined.yaml b/rl_sandbox/config/config_combined.yaml index b16cd88..55f70b0 100644 --- a/rl_sandbox/config/config_combined.yaml +++ b/rl_sandbox/config/config_combined.yaml @@ -1,6 +1,7 @@ defaults: - agent: dreamer_v2_slotted_combined - env: crafter + - logger: wandb - training: crafter - _self_ - override hydra/launcher: joblib @@ -9,17 +10,16 @@ seed: 42 device_type: cuda logger: - type: tensorboard - message: Crafter combined slotted 10x reconstruction + message: Combined, Fixed layer norm, add clamp log_grads: false training: checkpoint_path: null steps: 1e6 - val_logs_every: 5e3 + val_logs_every: 2e4 validation: - rollout_num: 3 + rollout_num: 5 visualize: true metrics: - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator @@ -32,13 +32,6 @@ debug: profiler: false hydra: - mode: MULTIRUN - #mode: RUN + mode: RUN launcher: - n_jobs: 8 - sweeper: - params: - agent.world_model.per_slot_rec_loss: true - agent.world_model.mask_combination: soft - agent.world_model.kl_loss_scale: 1e2,1e4 - agent.world_model.vit_l2_ratio: 0.1,0.01,0.9,0.99 + n_jobs: 1 diff --git a/rl_sandbox/config/config_default.yaml b/rl_sandbox/config/config_default.yaml index 910d50d..9a3d043 100644 --- a/rl_sandbox/config/config_default.yaml +++ b/rl_sandbox/config/config_default.yaml @@ -2,6 +2,7 @@ defaults: - agent: dreamer_v2_crafter - env: crafter - training: crafter + - logger: wandb - _self_ - override hydra/launcher: joblib @@ -9,14 +10,13 @@ seed: 42 device_type: cuda logger: - type: tensorboard message: Crafter default log_grads: false training: checkpoint_path: null steps: 1e6 - val_logs_every: 2e4 + val_logs_every: 5e4 validation: rollout_num: 5 diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index e946386..ed41f13 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -22,12 +22,12 @@ def on_step(self, logger): def on_episode(self, logger): pass - def on_val(self, logger, rollouts: list[Rollout]): + def on_val(self, logger, rollouts: list[Rollout], global_step: int): metrics = self.calculate_metrics(rollouts) - logger.log(metrics, self.episode, mode='val') + logger.log(metrics, global_step, mode='val') if self.log_video: video = rollouts[0].obs.unsqueeze(0) - logger.add_video('val/visualization', video.numpy() + 0.5, self.episode) + logger.add_video('val/visualization', ((video + 0.5) * 255).cpu().to(dtype=torch.uint8), global_step) self.episode += 1 def calculate_metrics(self, rollouts: list[Rollout]): @@ -67,7 +67,16 @@ def on_step(self, logger): 
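Note on the video logging changes in this commit: wandb.Video expects integer pixel data shaped (time, channel, height, width), so the evaluators now clamp and scale their float frames and cast to uint8 before handing them to the logger. A sketch of the same conversion and an offline wandb call, with a dummy rollout tensor and a hypothetical project name:

import numpy as np
import torch
import wandb

def to_uint8_video(video: torch.Tensor) -> np.ndarray:
    # video: (T, C, H, W) float frames in [0, 1]
    return (video.clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().numpy()

run = wandb.init(project="rl_sandbox_demo", mode="offline")  # hypothetical project
frames = torch.rand(16, 3, 64, 64)                           # dummy rollout frames
wandb.log({"val/visualization": wandb.Video(to_uint8_video(frames), fps=10)}, step=0)
run.finish()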
def on_episode(self, logger): latent_hist = (self._latent_probs / self.stored_steps).detach().cpu().numpy() - latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) + self.latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) + + self.reset_ep() + self.episode += 1 + + def on_val(self, logger, rollouts: list[Rollout], global_step: int): + self.viz_log(rollouts[0], logger, global_step) + + if self.episode == 0: + return # if discrete action space if self.agent.is_discrete: @@ -79,14 +88,8 @@ def on_episode(self, logger): else: # log mean +- std pass - logger.add_image('val/latent_probs', latent_hist, self.episode, dataformats='HW') - logger.add_image('val/latent_probs_sorted', np.sort(latent_hist, axis=1), self.episode, dataformats='HW') - - self.reset_ep() - self.episode += 1 - - def on_val(self, logger, rollouts: list[Rollout]): - self.viz_log(rollouts[0], logger, self.episode) + logger.add_image('val/latent_probs', self.latent_hist, global_step, dataformats='HW') + logger.add_image('val/latent_probs_sorted', np.sort(self.latent_hist, axis=1), global_step, dataformats='HW') def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): # obs = self.agent.preprocess_obs(obs) @@ -144,6 +147,7 @@ def viz_log(self, rollout, logger, epoch_num): videos_comparison = torch.cat([videos, videos_r, torch.abs(videos - videos_r + 1)/2], dim=2).unsqueeze(0) + videos_comparison = (videos_comparison.clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() @@ -165,19 +169,28 @@ def on_episode(self, logger): mu = wm.slot_attention.slots_mu sigma = wm.slot_attention.slots_logsigma.exp() - mu_hist = torch.mean((mu - mu.squeeze(0).unsqueeze(1)) ** 2, dim=-1) - sigma_hist = torch.mean((sigma - sigma.squeeze(0).unsqueeze(1)) ** 2, dim=-1) + self.mu_hist = torch.mean((mu - mu.squeeze(0).unsqueeze(1)) ** 2, dim=-1) + self.sigma_hist = torch.mean((sigma - sigma.squeeze(0).unsqueeze(1)) ** 2, dim=-1) + + + super().on_episode(logger) + + def on_val(self, logger, rollouts: list[Rollout], global_step: int): + super().on_val(logger, rollouts, global_step) + + if self.episode == 0: + return + + wm = self.agent.world_model if hasattr(wm.recurrent_model, 'last_attention'): logger.add_image('val/mixer_attention', wm.recurrent_model.last_attention, self.episode, dataformats='HW') - logger.add_image('val/slot_attention_mu', mu_hist/mu_hist.max(), self.episode, dataformats='HW') - logger.add_image('val/slot_attention_sigma', sigma_hist/sigma_hist.max(), self.episode, dataformats='HW') - - logger.add_scalar('val/slot_attention_mu_diff_max', mu_hist.max(), self.episode) - logger.add_scalar('val/slot_attention_sigma_diff_max', sigma_hist.max(), self.episode) + logger.add_image('val/slot_attention_mu', self.mu_hist/self.mu_hist.max(), self.episode, dataformats='HW') + logger.add_image('val/slot_attention_sigma', self.sigma_hist/self.sigma_hist.max(), self.episode, dataformats='HW') - super().on_episode(logger) + logger.add_scalar('val/slot_attention_mu_diff_max', self.mu_hist.max(), self.episode) + logger.add_scalar('val/slot_attention_sigma_diff_max', self.sigma_hist.max(), self.episode) def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): # obs = torch.from_numpy(obs.copy()).to(self.agent.device) @@ -258,9 +271,13 @@ def 
viz_log(self, rollout, logger, epoch_num): videos_comparison = torch.cat([videos, videos_r, torch.abs(videos - videos_r + 1)/2], dim=2).unsqueeze(0) + videos_comparison = (videos_comparison.clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) + slots_video = (slots_video.clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) logger.add_video('val/dreamed_slots', slots_video, epoch_num) + (videos_comparison * 255).to() + rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) @@ -367,6 +384,9 @@ def viz_log(self, rollout, logger, epoch_num): vit_masks_video = vit_masks_video.permute((0, 2, 3, 1, 4)) vit_masks_video = slots_video.reshape(*vit_masks_video.shape[:-2], -1).unsqueeze(0) + videos_comparison = (videos_comparison.clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) + slots_video = (slots_video.clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) + vit_masks_video = (vit_masks_video.clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) logger.add_video('val/dreamed_slots', slots_video, epoch_num) logger.add_video('val/dreamed_vit_masks', vit_masks_video, epoch_num) diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 475f8d8..859d1c2 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -1,6 +1,7 @@ import random import os os.environ['MUJOCO_GL'] = 'egl' +os.environ["WANDB_MODE"]="offline" import crafter import hydra @@ -22,13 +23,13 @@ iter_rollout) -def val_logs(agent, val_cfg: DictConfig, metrics, env: Env, logger: Logger): +def val_logs(agent, val_cfg: DictConfig, metrics, env: Env, logger: Logger, global_step: int): with torch.no_grad(): rollouts = collect_rollout_num(env, val_cfg.rollout_num, agent, collect_obs=True) rollouts = [agent.preprocess(r) for r in rollouts] - for metric in metrics: - metric.on_val(logger, rollouts) + for metric in metrics: + metric.on_val(logger, rollouts, global_step) @hydra.main(version_base="1.2", config_path='config', config_name='config') @@ -53,7 +54,7 @@ def main(cfg: DictConfig): # TODO: Implement smarter techniques for exploration # (Plan2Explore, etc) print(f'Start run: {cfg.logger.message}') - logger = Logger(**cfg.logger) + logger = Logger(**cfg.logger, cfg=cfg) env: Env = hydra.utils.instantiate(cfg.env) val_env: Env = hydra.utils.instantiate(cfg.env) @@ -104,7 +105,7 @@ def main(cfg: DictConfig): losses = agent.train(rollout_chunks) logger.log(losses, i, mode='pre_train') - val_logs(agent, cfg.validation, metrics, val_env, logger) + val_logs(agent, cfg.validation, metrics, val_env, logger, -1) if cfg.training.checkpoint_path is not None: prev_global_step = global_step = agent.load_ckpt(cfg.training.checkpoint_path) @@ -144,7 +145,7 @@ def main(cfg: DictConfig): ### Validation if (global_step % cfg.training.val_logs_every) <= (prev_global_step % cfg.training.val_logs_every): - val_logs(agent, cfg.validation, metrics, val_env, logger) + val_logs(agent, cfg.validation, metrics, val_env, logger, global_step) ### Checkpoint if (global_step % cfg.training.save_checkpoint_every) < ( diff --git a/rl_sandbox/utils/logger.py b/rl_sandbox/utils/logger.py index a54f166..5ff9b97 100644 --- a/rl_sandbox/utils/logger.py +++ b/rl_sandbox/utils/logger.py @@ -1,5 +1,8 @@ from torch.utils.tensorboard.writer import SummaryWriter +import wandb import typing as t +import omegaconf +from 
flatten_dict import flatten class SummaryWriterMock(): @@ -21,17 +24,50 @@ def add_histogram(*args, **kwargs): def add_figure(*args, **kwargs): pass +class WandbWriter(): + def __init__(self, project: str, comment: str, cfg: t.Optional[omegaconf.DictConfig]): + self.run = wandb.init( + project=project, + name=comment, + notes=comment, + config=flatten(omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True), reducer=lambda x, y: f"{x}-{y}" if x is not None else y) if cfg else None + ) + self.log_dir = wandb.run.dir + + def add_scalar(self, name: str, value: t.Any, global_step: int): + wandb.log({name: value}, step=global_step) + + def add_image(self, name: str, image: t.Any, global_step: int, dataformats: str = 'CHW'): + match dataformats: + case "CHW": + mode = "RGB" + case "HW": + mode = "L" + case _: + raise RuntimeError("Not supported dataformat") + wandb.log({name: wandb.Image(image, mode=mode)}, step=global_step) + + def add_video(self, name: str, video: t.Any, global_step: int, fps: int): + wandb.log({name: wandb.Video(video[0], fps=fps)}, step=global_step) + + def add_figure(self, name: str, figure: t.Any, global_step: int): + wandb.log({name: wandb.Image(figure)}, step=global_step) class Logger: def __init__(self, type: t.Optional[str], + cfg: t.Optional[omegaconf.DictConfig] = None, + project: t.Optional[str] = None, message: t.Optional[str] = None, log_grads: bool = True, log_dir: t.Optional[str] = None ) -> None: self.type = type + msg = message or "" match type: case "tensorboard": - self.writer = SummaryWriter(comment=message or "", log_dir=log_dir) + self.writer = SummaryWriter(comment=msg, log_dir=log_dir) + case "wandb": + self.writer = WandbWriter(project=project, comment=msg, cfg=cfg) case None: self.writer = SummaryWriterMock() case _: From 4cac3deed24fcb91f30299417461b5d17c35b995 Mon Sep 17 00:00:00 2001 From: Midren Date: Sun, 9 Jul 2023 10:36:20 +0100 Subject: [PATCH 078/106] Added Atari support and fixed encode_vit for combined --- rl_sandbox/agents/dreamer/vision.py | 2 +- rl_sandbox/agents/dreamer/world_model.py | 7 +- .../agents/dreamer/world_model_slots.py | 7 +- .../dreamer/world_model_slots_attention.py | 7 +- .../dreamer/world_model_slots_combined.py | 37 +++++-- rl_sandbox/agents/dreamer_v2.py | 21 +++- rl_sandbox/config/agent/dreamer_v2.yaml | 3 +- rl_sandbox/config/agent/dreamer_v2_atari.yaml | 27 +++++ .../config/agent/dreamer_v2_crafter.yaml | 49 +--------- .../agent/dreamer_v2_crafter_slotted.yaml | 66 +------------ .../agent/dreamer_v2_slotted_attention.yaml | 77 ++------------- .../agent/dreamer_v2_slotted_combined.yaml | 76 ++------------ rl_sandbox/config/config.yaml | 16 ++- rl_sandbox/config/config_attention.yaml | 2 +- rl_sandbox/config/config_combined.yaml | 19 +++- rl_sandbox/config/config_default.yaml | 2 +- rl_sandbox/config/config_slotted_debug.yaml | 21 ++-- rl_sandbox/config/env/atari.yaml | 8 ++ rl_sandbox/config/env/atari_amidar.yaml | 3 + rl_sandbox/config/env/atari_asterix.yaml | 3 + .../config/env/atari_chopper_command.yaml | 3 + rl_sandbox/config/env/atari_demon_attack.yaml | 3 + rl_sandbox/config/env/atari_freeway.yaml | 3 + rl_sandbox/config/env/atari_private_eye.yaml | 3 + rl_sandbox/config/env/atari_venture.yaml | 3 + .../config/env/atari_video_pinball.yaml | 3 + rl_sandbox/config/logger/tensorboard.yaml | 1 + rl_sandbox/config/logger/wandb.yaml | 2 + rl_sandbox/config/training/atari.yaml | 9 ++ rl_sandbox/metrics.py | 98 ++++++------------- rl_sandbox/train.py | 4 +- rl_sandbox/utils/env.py | 44 +++++++++ 32 
files changed, 272 insertions(+), 357 deletions(-) create mode 100644 rl_sandbox/config/agent/dreamer_v2_atari.yaml create mode 100644 rl_sandbox/config/env/atari.yaml create mode 100644 rl_sandbox/config/env/atari_amidar.yaml create mode 100644 rl_sandbox/config/env/atari_asterix.yaml create mode 100644 rl_sandbox/config/env/atari_chopper_command.yaml create mode 100644 rl_sandbox/config/env/atari_demon_attack.yaml create mode 100644 rl_sandbox/config/env/atari_freeway.yaml create mode 100644 rl_sandbox/config/env/atari_private_eye.yaml create mode 100644 rl_sandbox/config/env/atari_venture.yaml create mode 100644 rl_sandbox/config/env/atari_video_pinball.yaml create mode 100644 rl_sandbox/config/logger/tensorboard.yaml create mode 100644 rl_sandbox/config/logger/wandb.yaml create mode 100644 rl_sandbox/config/training/atari.yaml diff --git a/rl_sandbox/agents/dreamer/vision.py b/rl_sandbox/agents/dreamer/vision.py index dcce3bf..00a5d76 100644 --- a/rl_sandbox/agents/dreamer/vision.py +++ b/rl_sandbox/agents/dreamer/vision.py @@ -9,11 +9,11 @@ def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[4, 4, 4], double_conv=False, flatten_output=True, + in_channels=3, ): super().__init__() layers = [] - in_channels = 3 for i, k in enumerate(kernel_sizes): out_channels = 2**i * channel_step layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2)) diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index 718ac35..15140ae 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -17,10 +17,11 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, - actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, + actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float): super().__init__() self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.discount_scale = discount_loss_scale self.kl_beta = kl_loss_scale self.rssm_dim = rssm_dim @@ -166,7 +167,7 @@ def KL(dist1, dist2, free_nat = True): else: kl_lhs = KL_(Dist(dist2.detach()), Dist(dist1)).mean() kl_rhs = KL_(Dist(dist2), Dist(dist1.detach())).mean() - return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) priors = [] posteriors = [] @@ -226,6 +227,6 @@ def KL(dist1, dist2, free_nat = True): metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + - losses['loss_kl_reg'] + losses['loss_discount_pred']) + self.kl_beta * losses['loss_kl_reg'] + self.discount_scale*losses['loss_discount_pred']) return losses, posterior, metrics diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index 9636db3..6797c1b 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -18,13 +18,14 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, - actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, + actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, 
vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, mask_combination: str = 'soft'): super().__init__() self.use_prev_slots = use_prev_slots self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.discount_scale = discount_loss_scale self.kl_beta = kl_loss_scale self.rssm_dim = rssm_dim @@ -236,7 +237,7 @@ def KL(dist1, dist2): td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) - return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) priors = [] posteriors = [] @@ -321,7 +322,7 @@ def KL(dist1, dist2): metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + - losses['loss_kl_reg'] + losses['loss_discount_pred']) + self.kl_beta * losses['loss_kl_reg'] + self.discount_scale*losses['loss_discount_pred']) return losses, posterior, metrics diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index cc62fbf..7759d2f 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -18,7 +18,7 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, - actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, + actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, full_qk_from: int = 1, @@ -29,6 +29,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, super().__init__() self.use_prev_slots = use_prev_slots self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.discount_scale = discount_loss_scale self.kl_beta = kl_loss_scale self.rssm_dim = rssm_dim @@ -244,7 +245,7 @@ def KL(dist1, dist2): td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) - return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) priors = [] posteriors = [] @@ -373,7 +374,7 @@ def KL(dist1, dist2): metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + - losses['loss_kl_reg'] + losses['loss_discount_pred']) + self.kl_beta * losses['loss_kl_reg'] + self.discount_scale*losses['loss_discount_pred']) return losses, posterior, metrics diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index 900bf7f..cd3b35d 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -7,7 +7,7 @@ from torch import nn from torch.nn import functional as F -from rl_sandbox.agents.dreamer import Dist, Normalizer +from rl_sandbox.agents.dreamer import Dist, Normalizer, View from rl_sandbox.agents.dreamer.rssm_slots_combined import RSSM, State from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, ViTDecoder from rl_sandbox.utils.dists 
import DistLayer @@ -19,7 +19,7 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, - actions_num, kl_loss_scale, kl_loss_balancing, kl_free_nats, + actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, mask_combination: str = 'soft', @@ -27,6 +27,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, super().__init__() self.use_prev_slots = use_prev_slots self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.discount_scale = discount_loss_scale self.kl_beta = kl_loss_scale self.rssm_dim = rssm_dim @@ -70,16 +71,27 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.dino_vit.requires_grad_(False) if encode_vit: + self.post_vit = nn.Sequential( + View((-1, self.vit_feat_dim, 8, 8)), + Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[2], + channel_step=384, + double_conv=False, + flatten_output=False, + in_channels=self.vit_feat_dim + ) + ) self.encoder = nn.Sequential( self.dino_vit, - nn.Flatten(), + self.post_vit + ) + # nn.Flatten(), # fc_nn_generator(64*self.dino_vit.feat_dim, # 64*384, # hidden_size=400, # num_layers=5, # intermediate_activation=nn.ELU, # layer_norm=layer_norm) - ) else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4, 4], @@ -88,7 +100,10 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, flatten_output=False) self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num) - self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) + if self.encode_vit: + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) + else: + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), @@ -214,7 +229,11 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, self.recurrent_model.on_train_step() b, _, h, w = obs.shape # s <- BxHxWx3 - embed = self.encoder(obs) + if self.encode_vit: + embed = self.post_vit(additional['d_features']) + # embed = self.encoder(obs) + else: + embed = self.encoder(obs) embed_with_pos_enc = self.positional_augmenter_inp(embed) # embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) @@ -240,7 +259,7 @@ def KL(dist1, dist2): td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) - return (self.kl_beta * (self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) priors = [] posteriors = [] @@ -313,7 +332,7 @@ def KL(dist1, dist2): x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean() else: - img_rec = 0 + img_rec = torch.tensor(0, device=obs.device) decoded_imgs_detached, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=2) img_mask = self.slot_mask(masks) decoded_imgs_detached = decoded_imgs_detached * img_mask @@ -362,7 +381,7 @@ def KL(dist1, dist2): 
metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + - losses['loss_kl_reg'] + losses['loss_discount_pred']) + self.kl_beta * losses['loss_kl_reg'] + self.discount_scale*losses['loss_discount_pred']) return losses, posterior, metrics diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 6d9e407..e051c27 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -58,7 +58,8 @@ def __init__( self.critic: ImaginativeCritic = critic(latent_dim=self.world_model.state_size).to(device_type) self.world_model_optimizer = wm_optim(model=self.world_model, scaler=self.is_f16) - self.image_predictor_optimizer = wm_optim(model=self.world_model.image_predictor, scaler=self.is_f16) + if self.world_model.decode_vit and self.world_model.vit_l2_ratio == 1.0: + self.image_predictor_optimizer = wm_optim(model=self.world_model.image_predictor, scaler=self.is_f16) self.actor_optimizer = actor_optim(model=self.actor) self.critic_optimizer = critic_optim(model=self.critic) @@ -110,18 +111,32 @@ def preprocess(self, rollout: Rollout): additional_data=rollout.additional_data | additional) def preprocess_obs(self, obs: torch.Tensor): - # FIXME: move to dataloader in replay buffer order = list(range(len(obs.shape))) # Swap channel from last to 3 from last order = order[:-3] + [order[-1]] + order[-3:-1] if self.world_model.encode_vit: ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) - return ToTensor(obs.type(torch.float32).permute(order)) + return ToTensor(obs.type(torch.float32).permute(order) / 255.0) else: return ((obs.type(torch.float32) / 255.0) - 0.5).permute(order) # return obs.type(torch.float32).permute(order) + def unprocess_obs(self, obs: torch.Tensor): + order = list(range(len(obs.shape))) + # # Swap channel from last to 3 from last + order = order[:-3] + order[-2:] + [order[-3]] + if self.world_model.encode_vit: + fromTensor = tv.transforms.Compose([ tv.transforms.Normalize(mean = [ 0., 0., 0. ], + std = [ 1/0.229, 1/0.224, 1/0.225 ]), + tv.transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ], + std = [ 1., 1., 1. 
]), + ]) + return (fromTensor(obs).clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) + else: + return ((obs + 0.5).clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) + # return obs.type(torch.float32).permute(order) + def get_action(self, obs: Observation) -> Action: obs = torch.from_numpy(obs).to(self.device) obs = self.preprocess_obs(obs) diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index ad707b2..385d85d 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -12,6 +12,7 @@ world_model: latent_dim: 32 latent_classes: 32 rssm_dim: 200 + discount_loss_scale: 1.0 kl_loss_scale: 1e1 kl_loss_balancing: 0.8 kl_free_nats: 0.00 @@ -34,7 +35,7 @@ actor: critic: _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic _partial_: true - discount_factor: 0.999 + discount_factor: 0.99 update_interval: 100 # [0-1], 1 means hard update soft_update_fraction: 1 diff --git a/rl_sandbox/config/agent/dreamer_v2_atari.yaml b/rl_sandbox/config/agent/dreamer_v2_atari.yaml new file mode 100644 index 0000000..8f080e2 --- /dev/null +++ b/rl_sandbox/config/agent/dreamer_v2_atari.yaml @@ -0,0 +1,27 @@ +defaults: + - dreamer_v2 + - _self_ + +clip_rewards: tanh +layer_norm: true + +world_model: + rssm_dim: 600 + kl_loss_scale: 0.1 + discount_loss_scale: 5.0 + predict_discount: true + +actor: + entropy_scale: 1e-3 + +critic: + discount_factor: 0.999 + +wm_optim: + lr: 2e-4 + +actor_optim: + lr: 4e-5 + +critic_optim: + lr: 1e-4 diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index fc164c8..8839422 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -1,68 +1,25 @@ -_target_: rl_sandbox.agents.DreamerV2 +defaults: + - dreamer_v2 + - _self_ clip_rewards: tanh -imagination_horizon: 15 -batch_cluster_size: 50 layer_norm: true world_model: - _target_: rl_sandbox.agents.dreamer.world_model.WorldModel - _partial_: true - batch_cluster_size: ${..batch_cluster_size} - latent_dim: 32 - latent_classes: ${.latent_dim} rssm_dim: 1024 - kl_loss_scale: 1.0 - kl_loss_balancing: 0.8 - kl_free_nats: 0.00 - discrete_rssm: false - decode_vit: false - vit_l2_ratio: 1.0 - encode_vit: false predict_discount: true - layer_norm: ${..layer_norm} actor: - _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor - _partial_: true - # mixing of reinforce and maximizing value func - # for dm_control it is zero in Dreamer (Atari 1) - reinforce_fraction: null entropy_scale: 3e-3 - layer_norm: ${..layer_norm} critic: - _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic - _partial_: true discount_factor: 0.999 - update_interval: 100 - # [0-1], 1 means hard update - soft_update_fraction: 1 - # Lambda parameter for trainin deeper multi-step prediction - value_target_lambda: 0.95 - layer_norm: ${..layer_norm} wm_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr_scheduler: null lr: 1e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 actor_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true lr: 1e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 critic_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true lr: 1e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml index ca46343..2ca3961 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml 
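Note on the config reshuffle in this commit: the per-environment agent configs are reduced to Hydra defaults lists, so dreamer_v2_atari and dreamer_v2_crafter only state what differs from the base dreamer_v2. One way to inspect what a composed config resolves to (a sketch, assuming it is executed from the repository root so the relative config_path is valid):

from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base="1.2", config_path="rl_sandbox/config"):
    cfg = compose(config_name="config", overrides=["agent=dreamer_v2_crafter"])
    # Base dreamer_v2 values merged with the crafter-specific overrides.
    print(OmegaConf.to_yaml(cfg.agent.world_model))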
+++ b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml @@ -1,74 +1,14 @@ -_target_: rl_sandbox.agents.DreamerV2 - -clip_rewards: tanh -imagination_horizon: 15 -batch_cluster_size: 50 -layer_norm: true +defaults: + - dreamer_v2_crafter + - _self_ world_model: _target_: rl_sandbox.agents.dreamer.world_model_slots.WorldModel - _partial_: true - batch_cluster_size: ${..batch_cluster_size} - latent_dim: 32 - latent_classes: ${.latent_dim} rssm_dim: 512 slots_num: 6 slots_iter_num: 2 kl_loss_scale: 1e2 - kl_loss_balancing: 0.8 - kl_free_nats: 0.00 - discrete_rssm: false decode_vit: true use_prev_slots: false vit_l2_ratio: 0.1 encode_vit: false - predict_discount: true - layer_norm: ${..layer_norm} - -actor: - _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor - _partial_: true - # mixing of reinforce and maximizing value func - # for dm_control it is zero in Dreamer (Atari 1) - reinforce_fraction: null - entropy_scale: 3e-3 - layer_norm: ${..layer_norm} - -critic: - _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic - _partial_: true - discount_factor: 0.999 - update_interval: 100 - # [0-1], 1 means hard update - soft_update_fraction: 1 - # Lambda parameter for trainin deeper multi-step prediction - value_target_lambda: 0.95 - layer_norm: ${..layer_norm} - -wm_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr_scheduler: - - _target_: rl_sandbox.utils.optimizer.WarmupScheduler - _partial_: true - warmup_steps: 1e3 - lr: 1e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 - -actor_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr: 1e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 - -critic_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr: 1e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml index 2804e80..a0c184c 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -1,77 +1,20 @@ -_target_: rl_sandbox.agents.DreamerV2 - -clip_rewards: tanh -imagination_horizon: 15 -batch_cluster_size: 50 -layer_norm: true +defaults: + - dreamer_v2_crafter_slotted + - _self_ world_model: _target_: rl_sandbox.agents.dreamer.world_model_slots_attention.WorldModel - _partial_: true - batch_cluster_size: ${..batch_cluster_size} - latent_dim: 32 - latent_classes: ${.latent_dim} - rssm_dim: 1024 + rssm_dim: 512 slots_num: 6 slots_iter_num: 2 kl_loss_scale: 1e2 - kl_loss_balancing: 0.8 - kl_free_nats: 0.00 - discrete_rssm: false + encode_vit: false decode_vit: true - full_qk_from: 4e4 - symmetric_qk: true + mask_combination: soft use_prev_slots: false - attention_block_num: 1 + per_slot_rec_loss: false vit_l2_ratio: 0.1 - encode_vit: false - predict_discount: true - layer_norm: ${..layer_norm} - -actor: - _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor - _partial_: true - # mixing of reinforce and maximizing value func - # for dm_control it is zero in Dreamer (Atari 1) - reinforce_fraction: null - entropy_scale: 3e-3 - layer_norm: ${..layer_norm} -critic: - _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic - _partial_: true - discount_factor: 0.999 - update_interval: 100 - # [0-1], 1 means hard update - soft_update_fraction: 1 - # Lambda parameter for trainin deeper multi-step prediction - value_target_lambda: 0.95 - layer_norm: ${..layer_norm} - -wm_optim: - _target_: 
rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr_scheduler: - - _target_: rl_sandbox.utils.optimizer.WarmupScheduler - _partial_: true - warmup_steps: 1e3 - lr: 1e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 - -actor_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr: 1e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 - -critic_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr: 1e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 + full_qk_from: 4e4 + symmetric_qk: true + attention_block_num: 1 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml index 80920d3..a6c2bef 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml @@ -1,76 +1,16 @@ -_target_: rl_sandbox.agents.DreamerV2 - -clip_rewards: tanh -imagination_horizon: 15 -batch_cluster_size: 50 -layer_norm: true +defaults: + - dreamer_v2_crafter_slotted + - _self_ world_model: _target_: rl_sandbox.agents.dreamer.world_model_slots_combined.WorldModel - _partial_: true - batch_cluster_size: ${..batch_cluster_size} - latent_dim: 32 - latent_classes: ${.latent_dim} - rssm_dim: 1024 + rssm_dim: 512 slots_num: 6 slots_iter_num: 2 - kl_loss_scale: 1e2 # Try a bit higher to enforce more predicting the future instead of reconstruction - kl_loss_balancing: 0.8 - kl_free_nats: 0.00 - discrete_rssm: false - decode_vit: true - use_prev_slots: false - vit_l2_ratio: 0.1 + kl_loss_scale: 1e2 encode_vit: false + decode_vit: true mask_combination: soft + use_prev_slots: false per_slot_rec_loss: false - predict_discount: true - layer_norm: ${..layer_norm} - -actor: - _target_: rl_sandbox.agents.dreamer.ac.ImaginativeActor - _partial_: true - # mixing of reinforce and maximizing value func - # for dm_control it is zero in Dreamer (Atari 1) - reinforce_fraction: null - entropy_scale: 3e-3 - layer_norm: ${..layer_norm} - -critic: - _target_: rl_sandbox.agents.dreamer.ac.ImaginativeCritic - _partial_: true - discount_factor: 0.999 - update_interval: 100 - # [0-1], 1 means hard update - soft_update_fraction: 1 - # Lambda parameter for trainin deeper multi-step prediction - value_target_lambda: 0.95 - layer_norm: ${..layer_norm} - -wm_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr_scheduler: - - _target_: rl_sandbox.utils.optimizer.WarmupScheduler - _partial_: true - warmup_steps: 1e3 - lr: 1e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 - -actor_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr: 1e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 - -critic_optim: - _target_: rl_sandbox.utils.optimizer.Optimizer - _partial_: true - lr: 1e-4 - eps: 1e-5 - weight_decay: 1e-6 - clip: 100 + vit_l2_ratio: 0.1 diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 30b64d9..edc7788 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,8 +1,8 @@ defaults: - - agent: dreamer_v2_slotted_combined - - env: crafter - - training: crafter - - logger: wandb + - agent: dreamer_v2_atari + - env: atari_freeway + - training: atari + - logger: tensorboard - _self_ - override hydra/launcher: joblib @@ -10,7 +10,7 @@ seed: 42 device_type: cuda logger: - message: Crafter default + message: Atari with default dreamer log_grads: false training: @@ -38,8 +38,4 @@ hydra: n_jobs: 8 sweeper: params: - agent.world_model.per_slot_rec_loss: false,true - 
agent.world_model.mask_combination: soft,hard - agent.world_model.kl_loss_scale: 1e3 - agent.world_model.vit_l2_ratio: 0.9 - + env.task_name: amidar,asterix,chopper_command,demon_attack,freeway,private_eye,venture,video_pinball diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index e105576..7cd7b5b 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -39,6 +39,6 @@ hydra: # n_jobs: 8 #sweeper: # params: - # agent.world_model.full_qk_from: 1,1e6 + # agent.world_model.full_qk_from: 1,2e4 # agent.world_model.symmetric_qk: true,false # agent.world_model.attention_block_num: 1,3 diff --git a/rl_sandbox/config/config_combined.yaml b/rl_sandbox/config/config_combined.yaml index 55f70b0..d242a40 100644 --- a/rl_sandbox/config/config_combined.yaml +++ b/rl_sandbox/config/config_combined.yaml @@ -1,7 +1,7 @@ defaults: - agent: dreamer_v2_slotted_combined - env: crafter - - logger: wandb + - logger: tensorboard - training: crafter - _self_ - override hydra/launcher: joblib @@ -9,8 +9,15 @@ defaults: seed: 42 device_type: cuda +agent: + world_model: + #encode_vit: true + decode_vit: false + #vit_l2_ratio: 1.0 + kl_loss_scale: 1e2 + logger: - message: Combined, Fixed layer norm, add clamp + message: Combined, without dino log_grads: false training: @@ -32,6 +39,14 @@ debug: profiler: false hydra: + #mode: MULTIRUN mode: RUN launcher: n_jobs: 1 + #sweeper: + # params: + # agent.world_model.slots_num: 3,6 + # agent.world_model.per_slot_rec_loss: true + # agent.world_model.mask_combination: soft,hard + # agent.world_model.kl_loss_scale: 1e2 + # agent.world_model.vit_l2_ratio: 0.1,1e-3 diff --git a/rl_sandbox/config/config_default.yaml b/rl_sandbox/config/config_default.yaml index 9a3d043..0607307 100644 --- a/rl_sandbox/config/config_default.yaml +++ b/rl_sandbox/config/config_default.yaml @@ -16,7 +16,7 @@ logger: training: checkpoint_path: null steps: 1e6 - val_logs_every: 5e4 + val_logs_every: 2e4 validation: rollout_num: 5 diff --git a/rl_sandbox/config/config_slotted_debug.yaml b/rl_sandbox/config/config_slotted_debug.yaml index e717bf3..6388941 100644 --- a/rl_sandbox/config/config_slotted_debug.yaml +++ b/rl_sandbox/config/config_slotted_debug.yaml @@ -1,6 +1,7 @@ defaults: - - agent: dreamer_v2_crafter_slotted + - agent: dreamer_v2_slotted_combined - env: crafter + - logger: tensorboard - training: crafter - _self_ - override hydra/launcher: joblib @@ -8,9 +9,14 @@ defaults: seed: 42 device_type: cuda +agent: + world_model: + encode_vit: true + vit_l2_ratio: 1.0 + kl_loss_scale: 1e4 + logger: - type: tensorboard - message: Crafter 6 DINO slots + message: Combined encode vit log_grads: false training: @@ -25,7 +31,7 @@ validation: - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator log_video: True _partial_: true - - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + - _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator _partial_: true debug: @@ -38,6 +44,9 @@ hydra: n_jobs: 1 #sweeper: # params: - # agent.world_model.kl_loss_scale: 1e1,1e2,1e3,1e4 - # agent.world_model.vit_l2_ratio: 0.1,0.9 + # agent.world_model.slots_num: 3,6 + # agent.world_model.per_slot_rec_loss: true + # agent.world_model.mask_combination: soft,hard + # agent.world_model.kl_loss_scale: 1e2 + # agent.world_model.vit_l2_ratio: 0.1,1e-3 diff --git a/rl_sandbox/config/env/atari.yaml b/rl_sandbox/config/env/atari.yaml new file mode 100644 index 0000000..e4ac2d4 --- /dev/null +++ 
b/rl_sandbox/config/env/atari.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.AtariEnv +task_name: daemon_attack +sticky: true +obs_res: [64, 64] +repeat_action_num: 1 +life_done: false +greyscale: false +transforms: [] diff --git a/rl_sandbox/config/env/atari_amidar.yaml b/rl_sandbox/config/env/atari_amidar.yaml new file mode 100644 index 0000000..657132c --- /dev/null +++ b/rl_sandbox/config/env/atari_amidar.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: amidar diff --git a/rl_sandbox/config/env/atari_asterix.yaml b/rl_sandbox/config/env/atari_asterix.yaml new file mode 100644 index 0000000..8618320 --- /dev/null +++ b/rl_sandbox/config/env/atari_asterix.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: asterix diff --git a/rl_sandbox/config/env/atari_chopper_command.yaml b/rl_sandbox/config/env/atari_chopper_command.yaml new file mode 100644 index 0000000..12ced33 --- /dev/null +++ b/rl_sandbox/config/env/atari_chopper_command.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: chopper_command diff --git a/rl_sandbox/config/env/atari_demon_attack.yaml b/rl_sandbox/config/env/atari_demon_attack.yaml new file mode 100644 index 0000000..3239984 --- /dev/null +++ b/rl_sandbox/config/env/atari_demon_attack.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: demon_attack diff --git a/rl_sandbox/config/env/atari_freeway.yaml b/rl_sandbox/config/env/atari_freeway.yaml new file mode 100644 index 0000000..9e1555c --- /dev/null +++ b/rl_sandbox/config/env/atari_freeway.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: freeway diff --git a/rl_sandbox/config/env/atari_private_eye.yaml b/rl_sandbox/config/env/atari_private_eye.yaml new file mode 100644 index 0000000..67d16a6 --- /dev/null +++ b/rl_sandbox/config/env/atari_private_eye.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: private_eye diff --git a/rl_sandbox/config/env/atari_venture.yaml b/rl_sandbox/config/env/atari_venture.yaml new file mode 100644 index 0000000..f39acc3 --- /dev/null +++ b/rl_sandbox/config/env/atari_venture.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: venture diff --git a/rl_sandbox/config/env/atari_video_pinball.yaml b/rl_sandbox/config/env/atari_video_pinball.yaml new file mode 100644 index 0000000..7e2b8dc --- /dev/null +++ b/rl_sandbox/config/env/atari_video_pinball.yaml @@ -0,0 +1,3 @@ +defaults: + - atari +task_name: video_pinball diff --git a/rl_sandbox/config/logger/tensorboard.yaml b/rl_sandbox/config/logger/tensorboard.yaml new file mode 100644 index 0000000..8540962 --- /dev/null +++ b/rl_sandbox/config/logger/tensorboard.yaml @@ -0,0 +1 @@ +type: tensorboard diff --git a/rl_sandbox/config/logger/wandb.yaml b/rl_sandbox/config/logger/wandb.yaml new file mode 100644 index 0000000..d05c8be --- /dev/null +++ b/rl_sandbox/config/logger/wandb.yaml @@ -0,0 +1,2 @@ +type: wandb +project: slotted_dreamer diff --git a/rl_sandbox/config/training/atari.yaml b/rl_sandbox/config/training/atari.yaml new file mode 100644 index 0000000..272db19 --- /dev/null +++ b/rl_sandbox/config/training/atari.yaml @@ -0,0 +1,9 @@ +steps: 4e4 +prefill: 50000 +batch_size: 16 +f16_precision: false +pretrain: 1 +prioritize_ends: true +train_every: 4 +save_checkpoint_every: 5e5 +val_logs_every: 2e4 diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index ed41f13..b17aced 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -27,7 +27,7 @@ def on_val(self, logger, rollouts: list[Rollout], global_step: int): logger.log(metrics, global_step, mode='val') if self.log_video: video = 
rollouts[0].obs.unsqueeze(0) - logger.add_video('val/visualization', ((video + 0.5) * 255).cpu().to(dtype=torch.uint8), global_step) + logger.add_video('val/visualization', self.agent.unprocess_obs(video), global_step) self.episode += 1 def calculate_metrics(self, rollouts: list[Rollout]): @@ -99,21 +99,14 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ rews = [] state = None - means = np.array([0.485, 0.456, 0.406]) - stds = np.array([0.229, 0.224, 0.225]) - UnNormalize = tv.transforms.Normalize(list(-means/stds), - list(1/stds)) for idx, (o, a) in enumerate(list(zip(obs, actions))): if idx > update_num: break state = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), state) video_r = self.agent.world_model.image_predictor(state.combined).mode rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) - if self.agent.world_model.encode_vit: - video_r = UnNormalize(video_r) - else: - video_r = (video_r + 0.5) - video.append(video_r.clamp(0, 1)) + + video.append(self.agent.unprocess_obs(video_r)) rews = torch.Tensor(rews).to(obs.device) @@ -121,11 +114,8 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) rews = torch.cat([rews, rews_2[1:].squeeze()]) video_r = self.agent.world_model.image_predictor(states.combined[1:]).mode.detach() - if self.agent.world_model.encode_vit: - video_r = UnNormalize(video_r) - else: - video_r = (video_r + 0.5) - video.append(video_r) + + video.append(self.agent.unprocess_obs(video_r)) return torch.cat(video), rews @@ -135,7 +125,8 @@ def viz_log(self, rollout, logger, epoch_num): videos = torch.cat([ rollout.obs[init_idx:init_idx + self.agent.imagination_horizon] for init_idx in init_indeces - ], dim=3) + 0.5 + ], dim=3) + videos = self.agent.unprocess_obs(videos) real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] @@ -145,9 +136,8 @@ def viz_log(self, rollout, logger, epoch_num): ]) videos_r = torch.cat(videos_r, dim=3) - videos_comparison = torch.cat([videos, videos_r, torch.abs(videos - videos_r + 1)/2], dim=2).unsqueeze(0) + videos_comparison = torch.cat([videos, videos_r, (torch.abs(videos.float() - videos_r.float() + 1)/2).to(dtype=torch.uint8)], dim=2).unsqueeze(0) - videos_comparison = (videos_comparison.clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() @@ -204,10 +194,6 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ state = None prev_slots = None - means = np.array([0.485, 0.456, 0.406]) - stds = np.array([0.229, 0.224, 0.225]) - UnNormalize = tv.transforms.Normalize(list(-means/stds), - list(1/stds)) for idx, (o, a) in enumerate(list(zip(obs, actions))): if idx > update_num: break @@ -221,12 +207,8 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ video_r = torch.sum(decoded_imgs, dim=1) rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) - if self.agent.world_model.encode_vit: - video_r = UnNormalize(video_r) - else: - video_r = (video_r + 0.5) - video.append(video_r.clamp(0, 1)) - slots_video.append((decoded_imgs + 0.5).clamp(0, 1)) + 
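# unprocess_obs (the agent helper shown earlier in this patch) presumably undoes the observation
# preprocessing: it un-normalizes ImageNet-style inputs when encode_vit is set, otherwise shifts
# from [-0.5, 0.5] back to [0, 1], and finally scales to uint8 [0, 255] so videos can be logged directly.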
video.append(self.agent.unprocess_obs(video_r)) + slots_video.append(self.agent.unprocess_obs(decoded_imgs)) rews = torch.Tensor(rews).to(obs.device) @@ -240,12 +222,8 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ decoded_imgs = decoded_imgs * img_mask video_r = torch.sum(decoded_imgs, dim=1) - if self.agent.world_model.encode_vit: - video_r = UnNormalize(video_r) - else: - video_r = (video_r + 0.5) - video.append(video_r) - slots_video.append(decoded_imgs + 0.5) + video.append(self.agent.unprocess_obs(video_r)) + slots_video.append(self.agent.unprocess_obs(decoded_imgs)) return torch.cat(video), rews, torch.cat(slots_video) @@ -255,7 +233,8 @@ def viz_log(self, rollout, logger, epoch_num): videos = torch.cat([ rollout.obs[init_idx:init_idx + self.agent.imagination_horizon] for init_idx in init_indeces - ], dim=3) + 0.5 + ], dim=3) + videos = self.agent.unprocess_obs(videos) real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] @@ -265,19 +244,15 @@ def viz_log(self, rollout, logger, epoch_num): ]) videos_r = torch.cat(videos_r, dim=3) + videos_comparison = torch.cat([videos, videos_r, (torch.abs(videos.float() - videos_r.float() + 1)/2).to(dtype=torch.uint8)], dim=2).unsqueeze(0) slots_video = torch.cat(list(slots_video)[:3], dim=3) + slots_video = slots_video.permute((0, 2, 3, 1, 4)) slots_video = slots_video.reshape(*slots_video.shape[:-2], -1).unsqueeze(0) - videos_comparison = torch.cat([videos, videos_r, torch.abs(videos - videos_r + 1)/2], dim=2).unsqueeze(0) - - videos_comparison = (videos_comparison.clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) - slots_video = (slots_video.clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) logger.add_video('val/dreamed_slots', slots_video, epoch_num) - (videos_comparison * 255).to() - rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) @@ -297,10 +272,6 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ state = None prev_slots = None - means = np.array([0.485, 0.456, 0.406]) - stds = np.array([0.229, 0.224, 0.225]) - UnNormalize = tv.transforms.Normalize(list(-means/stds), - list(1/stds)) for idx, (o, a) in enumerate(list(zip(obs, actions))): if idx > update_num: break @@ -316,17 +287,14 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ _, vit_masks = self.agent.world_model.dino_predictor(state.combined_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, 8, 8).split([self.agent.world_model.vit_feat_dim, 1], dim=2) vit_mask = F.softmax(vit_masks, dim=1) upscale = tv.transforms.Resize(64, antialias=True) + upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) per_slot_vit = (upscaled_mask.unsqueeze(1) * o.to(self.agent.device).unsqueeze(0)).unsqueeze(0) rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) - if self.agent.world_model.encode_vit: - video_r = UnNormalize(video_r) - else: - video_r = (video_r + 0.5) - video.append(video_r.clamp(0, 1)) - slots_video.append((decoded_imgs + 0.5).clamp(0, 1)) - vit_slots_video.append((per_slot_vit/upscaled_mask.max() + 0.5).clamp(0, 1)) + video.append(self.agent.unprocess_obs(video_r)) + slots_video.append(self.agent.unprocess_obs(decoded_imgs)) + 
vit_slots_video.append(self.agent.unprocess_obs(per_slot_vit/upscaled_mask.max())) rews = torch.Tensor(rews).to(obs.device) @@ -347,14 +315,9 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ per_slot_vit = (upscaled_mask.unsqueeze(2) * obs[update_num+1:].to(self.agent.device).unsqueeze(1)) # per_slot_vit = (upscaled_mask.unsqueeze(1) * o.to(self.agent.device).unsqueeze(0)).unsqueeze(0) - if self.agent.world_model.encode_vit: - video_r = UnNormalize(video_r) - else: - video_r = (video_r + 0.5) - video.append(video_r.clamp(0, 1)) - slots_video.append((decoded_imgs + 0.5).clamp(0, 1)) - vit_slots_video = None # FIXME: this is not correct - # vit_slots_video.append(per_slot_vit/np.expand_dims(upscaled_mask.max(axis=(1,2,3)), axis=(1,2,3,4)) + 0.5) + video.append(self.agent.unprocess_obs(video_r)) + slots_video.append(self.agent.unprocess_obs(decoded_imgs)) + vit_slots_video.append(self.agent.unprocess_obs(per_slot_vit/torch.amax(upscaled_mask, dim=(1,2,3)).view(-1, 1, 1, 1, 1))) return torch.cat(video), rews, torch.cat(slots_video), torch.cat(vit_slots_video) @@ -364,29 +327,28 @@ def viz_log(self, rollout, logger, epoch_num): videos = torch.cat([ rollout.obs[init_idx:init_idx + self.agent.imagination_horizon] for init_idx in init_indeces - ], dim=3) + 0.5 + ], dim=3) + videos = self.agent.unprocess_obs(videos) real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] - videos_r, imagined_rewards, slots_video, vit_masks_video = zip(*[self._generate_video(obs_0.copy(), a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( + videos_r, imagined_rewards, slots_video, vit_masks_video = zip(*[self._generate_video(obs_0, a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( [rollout.obs[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], [rollout.actions[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces]) ]) videos_r = torch.cat(videos_r, dim=3) + + videos_comparison = torch.cat([videos, videos_r, (torch.abs(videos.float() - videos_r.float() + 1)/2).to(dtype=torch.uint8)], dim=2).unsqueeze(0) + slots_video = torch.cat(list(slots_video)[:3], dim=3) slots_video = slots_video.permute((0, 2, 3, 1, 4)) slots_video = slots_video.reshape(*slots_video.shape[:-2], -1).unsqueeze(0) - videos_comparison = torch.cat([videos, videos_r, torch.abs(videos - videos_r + 1)/2], dim=2).unsqueeze(0) - vit_masks_video = torch.cat(list(vit_masks_video)[:3], dim=3) vit_masks_video = vit_masks_video.permute((0, 2, 3, 1, 4)) - vit_masks_video = slots_video.reshape(*vit_masks_video.shape[:-2], -1).unsqueeze(0) + vit_masks_video = vit_masks_video.reshape(*vit_masks_video.shape[:-2], -1).unsqueeze(0) - videos_comparison = (videos_comparison.clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) - slots_video = (slots_video.clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) - vit_masks_video = (vit_masks_video.clamp(0, 1) * 255).cpu().to(dtype=torch.uint8) logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) logger.add_video('val/dreamed_slots', slots_video, epoch_num) logger.add_video('val/dreamed_vit_masks', vit_masks_video, epoch_num) diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 859d1c2..0a9ad7e 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -105,7 +105,7 @@ def main(cfg: DictConfig): losses = agent.train(rollout_chunks) logger.log(losses, i, mode='pre_train') - val_logs(agent, cfg.validation, metrics, val_env, logger, -1) + val_logs(agent, 
cfg.validation, metrics, val_env, logger, 0) if cfg.training.checkpoint_path is not None: prev_global_step = global_step = agent.load_ckpt(cfg.training.checkpoint_path) @@ -129,7 +129,7 @@ def main(cfg: DictConfig): losses = agent.train(rollout_chunk) if cfg.debug.profiler: prof.step() - if global_step % 100 == 0: + if global_step % 1000 == 0: logger.log(losses, global_step, mode='train') for metric in metrics: diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py index 39b1eda..101a708 100644 --- a/rl_sandbox/utils/env.py +++ b/rl_sandbox/utils/env.py @@ -129,6 +129,50 @@ def action_space(self) -> gym.Space: space = t.transform_space(t) return space +class AtariEnv(Env): + + def __init__(self, task_name: str, obs_res: tuple[int, int], sticky: bool, life_done: bool, greyscale: bool, + repeat_action_num: int, transforms: list[ActionTransformer]): + import gym.wrappers + import gym.envs.atari + super().__init__(True, obs_res, repeat_action_num, transforms) + + self.env: gym.Env = gym.envs.atari.AtariEnv(game=task_name, obs_type='image', frameskip=1, repeat_action_probability=0.25 if sticky else 0, full_action_space=False) + # Tell wrapper that the inner env has no action repeat. + self.env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0') + self.env = gym.wrappers.AtariPreprocessing(self.env, + 30, repeat_action_num, obs_res[0], + life_done, greyscale) + + + def render(self): + raise RuntimeError("Render is not supported for AtariEnv") + + def _step(self, action: Action, repeat_num: int) -> EnvStepResult: + rew = 0 + for _ in range(repeat_num - 1): + new_state, reward, terminated, _ = self.env.step(action) + ts = EnvStepResult(new_state, reward, terminated) + if terminated: + break + rew += reward or 0.0 + if repeat_num == 1 or not terminated: + new_state, reward, terminated, _ = self.env.step(action) + env_res = EnvStepResult(new_state, reward, terminated) + else: + env_res = ts + env_res.reward = rew + (env_res.reward or 0.0) + return env_res + + def reset(self): + state = self.env.reset() + return EnvStepResult(state, 0, False) + + def _observation_space(self): + return self.env.observation_space + + def _action_space(self): + return self.env.action_space class GymEnv(Env): From bdda68594510e2a604aeb56c697beffa65b5457c Mon Sep 17 00:00:00 2001 From: Midren Date: Wed, 12 Jul 2023 08:49:32 +0100 Subject: [PATCH 079/106] Added option for choosing 224 vs 64 dino features --- rl_sandbox/agents/dreamer/world_model.py | 75 ++++++++++------- .../agents/dreamer/world_model_slots.py | 82 +++++++++++-------- .../dreamer/world_model_slots_attention.py | 72 +++++++++------- .../dreamer/world_model_slots_combined.py | 52 +++++------- rl_sandbox/config/agent/dreamer_v2.yaml | 3 +- rl_sandbox/config/config.yaml | 21 +++-- rl_sandbox/utils/env.py | 2 +- 7 files changed, 170 insertions(+), 137 deletions(-) diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index 15140ae..e2396f6 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -6,7 +6,7 @@ from torch import nn from torch.nn import functional as F -from rl_sandbox.agents.dreamer import Dist, Normalizer +from rl_sandbox.agents.dreamer import Dist, Normalizer, View from rl_sandbox.agents.dreamer.rssm import RSSM, State from rl_sandbox.agents.dreamer.vision import Decoder, Encoder from rl_sandbox.utils.dists import DistLayer @@ -18,7 +18,8 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, 
actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, - predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float): + predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, + vit_l2_ratio: float, vit_img_size: int): super().__init__() self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) self.discount_scale = discount_loss_scale @@ -37,6 +38,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.encode_vit = encode_vit self.decode_vit = decode_vit self.vit_l2_ratio = vit_l2_ratio + self.vit_img_size = vit_img_size self.recurrent_model = RSSM(latent_dim, rssm_dim, @@ -45,24 +47,37 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, discrete_rssm, norm_layer=nn.LayerNorm if layer_norm else nn.Identity) if encode_vit or decode_vit: + if self.vit_img_size == 224: + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=16) + self.decoder_kernels = [3, 3, 2] + self.vit_size = 14 + elif self.vit_img_size == 64: + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=8) + self.decoder_kernels = [3, 4] + self.vit_size = 8 + else: + raise RuntimeError("Unknown vit img size") # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) - self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) - # self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=16) self.vit_feat_dim = self.dino_vit.feat_dim - self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches self.dino_vit.requires_grad_(False) if encode_vit: + self.post_vit = nn.Sequential( + View((-1, self.vit_feat_dim, self.vit_size, self.vit_size)), + Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[2], + channel_step=384, + double_conv=False, + flatten_output=False, + in_channels=self.vit_feat_dim + ) + ) self.encoder = nn.Sequential( self.dino_vit, - nn.Flatten(), - # fc_nn_generator(64*self.dino_vit.feat_dim, - # 64*384, - # hidden_size=400, - # num_layers=5, - # intermediate_activation=nn.ELU, - # layer_norm=layer_norm) - ) + self.post_vit + ) else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4, 4, 4], @@ -74,16 +89,9 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.dino_predictor = Decoder(self.state_size, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, channel_step=192, - kernel_sizes=[3, 4], + kernel_sizes=self.decoder_kernels, output_channels=self.vit_feat_dim, return_dist=True) - # self.dino_predictor = fc_nn_generator(self.state_size, - # 64*self.dino_vit.feat_dim, - # hidden_size=2048, - # num_layers=5, - # intermediate_activation=nn.ELU, - # layer_norm=layer_norm, - # final_activation=DistLayer('mse')) self.image_predictor = Decoder(self.state_size, norm_layer=nn.GroupNorm if layer_norm else nn.Identity) @@ -107,11 +115,9 @@ def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: if not self.decode_vit: return {} if not self.encode_vit: - # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), - # (0.229, 0.224, 0.225)), - # 
tv.transforms.Resize(224, antialias=True)]) - ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(self.vit_img_size, antialias=True)]) obs = ToTensor(obs + 0.5) with torch.no_grad(): d_features = self.dino_vit(obs).cpu() @@ -146,7 +152,10 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, self.recurrent_model.on_train_step() b, _, h, w = obs.shape # s <- BxHxWx3 - embed = self.encoder(obs) + if self.encode_vit: + embed = self.post_vit(additional['d_features']) + else: + embed = self.encoder(obs) embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) a_c = a.reshape(-1, self.cluster_size, self.actions_num) @@ -208,11 +217,15 @@ def KL(dist1, dist2, free_nat = True): img_rec = 0 x_r_detached = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1).detach()) losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() + d_pred = self.dino_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) - losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 8, 8)).float().mean() + - (1-self.vit_l2_ratio) * img_rec) - # losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.flatten(1, 2)).float().mean() + - # (1-self.vit_l2_ratio) * img_rec) + d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) + d_rec = -d_pred.log_prob(d_obs).float().mean() + d_rec = d_rec / torch.prod(torch.Tensor(d_obs.shape[-3:])) * torch.prod(torch.Tensor(obs.shape[-3:])) + + losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) + metrics['loss_l2_rec'] = img_rec + metrics['loss_dino_rec'] = d_rec prior_logits = prior.stoch_logits posterior_logits = posterior.stoch_logits diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index 6797c1b..38eb654 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -6,7 +6,7 @@ from torch import nn from torch.nn import functional as F -from rl_sandbox.agents.dreamer import Dist, Normalizer +from rl_sandbox.agents.dreamer import Dist, Normalizer, View from rl_sandbox.agents.dreamer.rssm_slots import RSSM, State from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, ViTDecoder from rl_sandbox.utils.dists import DistLayer @@ -20,7 +20,7 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, - decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, + decode_vit: bool, vit_l2_ratio: float, vit_img_size: int, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, mask_combination: str = 'soft'): super().__init__() self.use_prev_slots = use_prev_slots @@ -43,6 +43,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.encode_vit = encode_vit self.decode_vit = decode_vit self.vit_l2_ratio = vit_l2_ratio + self.vit_img_size = vit_img_size self.n_dim = 384 @@ -55,27 +56,36 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, norm_layer=nn.LayerNorm if layer_norm else 
nn.Identity, embed_size=self.n_dim) if encode_vit or decode_vit: + if self.vit_img_size == 224: + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=16) + self.decoder_kernels = [3, 3, 2] + self.vit_size = 14 + elif self.vit_img_size == 64: + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=8) + self.decoder_kernels = [3, 4] + self.vit_size = 8 + else: + raise RuntimeError("Unknown vit img size") # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) - self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) - # self.dino_vit = ViTFeat( - # "/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", - # feat_dim=384, - # vit_arch='small', - # patch_size=16) self.vit_feat_dim = self.dino_vit.feat_dim - self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches self.dino_vit.requires_grad_(False) if encode_vit: + self.post_vit = nn.Sequential( + View((-1, self.vit_feat_dim, self.vit_size, self.vit_size)), + Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[2], + channel_step=384, + double_conv=False, + flatten_output=False, + in_channels=self.vit_feat_dim + ) + ) self.encoder = nn.Sequential( self.dino_vit, - nn.Flatten(), - # fc_nn_generator(64*self.dino_vit.feat_dim, - # 64*384, - # hidden_size=400, - # num_layers=5, - # intermediate_activation=nn.ELU, - # layer_norm=layer_norm) + self.post_vit ) else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, @@ -85,8 +95,10 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, flatten_output=False) self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num) - self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) - # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) + if self.encode_vit: + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) + else: + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), nn.ReLU(inplace=True), @@ -96,17 +108,9 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, channel_step=192, - # kernel_sizes=[5, 5, 4], # for size 224x224 - kernel_sizes=[3, 4], + kernel_sizes=self.decoder_kernels, output_channels=self.vit_feat_dim+1, return_dist=False) - # self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, - # 64*self.dino_vit.feat_dim, - # hidden_size=2048, - # num_layers=5, - # intermediate_activation=nn.ELU, - # layer_norm=layer_norm, - # final_activation=DistLayer('mse')) self.image_predictor = Decoder( rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, @@ -146,11 +150,9 @@ def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: if not self.decode_vit: return {} if not self.encode_vit: - # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), - # (0.229, 0.224, 0.225)), - # tv.transforms.Resize(224, antialias=True)]) - ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + 
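# (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225) are the standard ImageNet mean/std expected by the
# pretrained DINO ViT; obs is shifted by +0.5 from the agent's [-0.5, 0.5] range back to [0, 1]
# before normalizing, and resized to vit_img_size (224 or 64) to match the selected backbone.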
ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(self.vit_img_size, antialias=True)]) obs = ToTensor(obs + 0.5) d_features = self.dino_vit(obs) return {'d_features': d_features} @@ -212,8 +214,11 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) + if self.encode_vit: + embed = self.post_vit(additional['d_features']) + else: + embed = self.encoder(obs) embed_with_pos_enc = self.positional_augmenter_inp(embed) - # embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) pre_slot_features = self.slot_mlp( embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) @@ -302,12 +307,17 @@ def KL(dist1, dist2): x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() - decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, 8, 8).split([self.vit_feat_dim, 1], dim=2) + decoded_feats, masks = d_features.reshape(b, -1, self.vit_feat_dim+1, self.vit_size, self.vit_size).split([self.vit_feat_dim, 1], dim=2) feat_mask = self.slot_mask(masks) decoded_feats = decoded_feats * feat_mask d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) - losses['loss_reconstruction'] = (self.vit_l2_ratio * -d_pred.log_prob(d_features.reshape(b, self.vit_feat_dim, 8, 8)).float().mean() + - (1-self.vit_l2_ratio) * img_rec) + d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) + d_rec = -d_pred.log_prob(d_obs).float().mean() + d_rec = d_rec / torch.prod(torch.Tensor(d_obs.shape[-3:])) * torch.prod(torch.Tensor(obs.shape[-3:])) + + losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) + metrics['loss_l2_rec'] = img_rec + metrics['loss_dino_rec'] = d_rec prior_logits = prior.stoch_logits posterior_logits = posterior.stoch_logits diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index 7759d2f..5bdfbde 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -6,7 +6,7 @@ from torch import nn from torch.nn import functional as F -from rl_sandbox.agents.dreamer import Dist, Normalizer +from rl_sandbox.agents.dreamer import Dist, Normalizer, View from rl_sandbox.agents.dreamer.rssm_slots_attention import RSSM, State from rl_sandbox.agents.dreamer.vision import Decoder, Encoder from rl_sandbox.utils.dists import DistLayer @@ -20,7 +20,7 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, - decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, + decode_vit: bool, vit_l2_ratio: float, vit_img_size: int, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, full_qk_from: int = 1, symmetric_qk: bool = False, attention_block_num: int = 3, @@ -47,6 +47,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.encode_vit = encode_vit self.decode_vit = decode_vit self.vit_l2_ratio = vit_l2_ratio + self.vit_img_size = 
vit_img_size self.per_slot_rec_loss = per_slot_rec_loss self.n_dim = 384 @@ -63,27 +64,36 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, symmetric_qk=symmetric_qk, attention_block_num=attention_block_num) if encode_vit or decode_vit: + if self.vit_img_size == 224: + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=16) + self.decoder_kernels = [3, 3, 2] + self.vit_size = 14 + elif self.vit_img_size == 64: + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=8) + self.decoder_kernels = [3, 4] + self.vit_size = 8 + else: + raise RuntimeError("Unknown vit img size") # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) - self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) - # self.dino_vit = ViTFeat( - # "/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", - # feat_dim=384, - # vit_arch='small', - # patch_size=16) self.vit_feat_dim = self.dino_vit.feat_dim - self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches self.dino_vit.requires_grad_(False) if encode_vit: + self.post_vit = nn.Sequential( + View((-1, self.vit_feat_dim, self.vit_size, self.vit_size)), + Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[2], + channel_step=384, + double_conv=False, + flatten_output=False, + in_channels=self.vit_feat_dim + ) + ) self.encoder = nn.Sequential( self.dino_vit, - nn.Flatten(), - # fc_nn_generator(64*self.dino_vit.feat_dim, - # 64*384, - # hidden_size=400, - # num_layers=5, - # intermediate_activation=nn.ELU, - # layer_norm=layer_norm) + self.post_vit ) else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, @@ -93,7 +103,10 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, flatten_output=False) self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num) - self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) + if self.encode_vit: + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) + else: + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), @@ -108,13 +121,6 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, kernel_sizes=[3, 4], output_channels=self.vit_feat_dim+1, return_dist=False) - # self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, - # 64*self.dino_vit.feat_dim, - # hidden_size=2048, - # num_layers=5, - # intermediate_activation=nn.ELU, - # layer_norm=layer_norm, - # final_activation=DistLayer('mse')) self.image_predictor = Decoder( rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, @@ -154,11 +160,9 @@ def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: if not self.decode_vit: return {} if not self.encode_vit: - # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), - # (0.229, 0.224, 0.225)), - # tv.transforms.Resize(224, antialias=True)]) - ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + ToTensor = 
tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(self.vit_img_size, antialias=True)]) obs = ToTensor(obs + 0.5) d_features = self.dino_vit(obs) return {'d_features': d_features} @@ -220,8 +224,11 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, b, _, h, w = obs.shape # s <- BxHxWx3 embed = self.encoder(obs) + if self.encode_vit: + embed = self.post_vit(additional['d_features']) + else: + embed = self.encoder(obs) embed_with_pos_enc = self.positional_augmenter_inp(embed) - # embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) pre_slot_features = self.slot_mlp( embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) @@ -345,7 +352,7 @@ def KL(dist1, dist2): decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, 8, 8).split([self.vit_feat_dim, 1], dim=2) feat_mask = self.slot_mask(masks) - d_obs = d_features.reshape(b, self.vit_feat_dim, 8, 8) + d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) if self.per_slot_rec_loss: l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) @@ -358,7 +365,10 @@ def KL(dist1, dist2): d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) d_rec = -d_pred.log_prob(d_obs).float().mean() + d_rec = d_rec / torch.prod(torch.Tensor(d_obs.shape[-3:])) * torch.prod(torch.Tensor(obs.shape[-3:])) losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) + metrics['loss_l2_rec'] = img_rec + metrics['loss_dino_rec'] = d_rec prior_logits = prior.stoch_logits posterior_logits = posterior.stoch_logits diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index cd3b35d..7d34f41 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -21,7 +21,7 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, - decode_vit: bool, vit_l2_ratio: float, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, + decode_vit: bool, vit_l2_ratio: float, vit_img_size: int, slots_num: int, slots_iter_num: int, use_prev_slots: bool = True, mask_combination: str = 'soft', per_slot_rec_loss: bool = False): super().__init__() @@ -45,6 +45,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.encode_vit = encode_vit self.decode_vit = decode_vit self.vit_l2_ratio = vit_l2_ratio + self.vit_img_size = vit_img_size self.per_slot_rec_loss = per_slot_rec_loss self.n_dim = 384 @@ -59,20 +60,25 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, slots_num=slots_num, embed_size=self.n_dim) if encode_vit or decode_vit: + if self.vit_img_size == 224: + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=16) + self.decoder_kernels = [3, 3, 2] + self.vit_size = 14 + elif self.vit_img_size == 64: + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=8) + self.decoder_kernels = [3, 4] + self.vit_size = 8 + else: + 
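# only the two small DINO checkpoints are wired up here: patch-16 at 224 px (14x14 feature grid)
# and patch-8 at 64 px (8x8 feature grid); any other vit_img_size is rejected below.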
raise RuntimeError("Unknown vit img size") # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) - self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", feat_dim=384, vit_arch='small', patch_size=8) - # self.dino_vit = ViTFeat( - # "/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", - # feat_dim=384, - # vit_arch='small', - # patch_size=16) self.vit_feat_dim = self.dino_vit.feat_dim - self.vit_num_patches = self.dino_vit.model.patch_embed.num_patches self.dino_vit.requires_grad_(False) if encode_vit: self.post_vit = nn.Sequential( - View((-1, self.vit_feat_dim, 8, 8)), + View((-1, self.vit_feat_dim, self.vit_size, self.vit_size)), Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[2], channel_step=384, @@ -85,13 +91,6 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.dino_vit, self.post_vit ) - # nn.Flatten(), - # fc_nn_generator(64*self.dino_vit.feat_dim, - # 64*384, - # hidden_size=400, - # num_layers=5, - # intermediate_activation=nn.ELU, - # layer_norm=layer_norm) else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4, 4], @@ -104,7 +103,6 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) else: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) - # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), nn.ReLU(inplace=True), @@ -118,13 +116,6 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, kernel_sizes=[3, 4], output_channels=self.vit_feat_dim+1, return_dist=False) - # self.dino_predictor = fc_nn_generator(rssm_dim + latent_dim*latent_classes, - # 64*self.dino_vit.feat_dim, - # hidden_size=2048, - # num_layers=5, - # intermediate_activation=nn.ELU, - # layer_norm=layer_norm, - # final_activation=DistLayer('mse')) self.image_predictor = Decoder( rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, @@ -164,11 +155,9 @@ def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: if not self.decode_vit: return {} if not self.encode_vit: - # ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), - # (0.229, 0.224, 0.225)), - # tv.transforms.Resize(224, antialias=True)]) - ToTensor = tv.transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) + ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(self.vit_img_size, antialias=True)]) obs = ToTensor(obs + 0.5) d_features = self.dino_vit(obs).squeeze() return {'d_features': d_features} @@ -231,7 +220,6 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, if self.encode_vit: embed = self.post_vit(additional['d_features']) - # embed = self.encoder(obs) else: embed = self.encoder(obs) embed_with_pos_enc = self.positional_augmenter_inp(embed) @@ -351,7 +339,7 @@ def KL(dist1, dist2): decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, 8, 8).split([self.vit_feat_dim, 1], dim=2) feat_mask = self.slot_mask(masks) - d_obs = d_features.reshape(b, self.vit_feat_dim, 8, 8) + d_obs = d_features.reshape(b, 
self.vit_feat_dim, self.vit_size, self.vit_size) decoded_feats = decoded_feats * feat_mask if self.per_slot_rec_loss: @@ -364,6 +352,8 @@ def KL(dist1, dist2): d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) d_rec = -d_pred.log_prob(d_obs).float().mean() + d_rec = d_rec / torch.prod(torch.Tensor(d_obs.shape[-3:])) * torch.prod(torch.Tensor(obs.shape[-3:])) + losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) metrics['loss_l2_rec'] = img_rec metrics['loss_dino_rec'] = d_rec diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index 385d85d..172aecc 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -18,7 +18,8 @@ world_model: kl_free_nats: 0.00 discrete_rssm: false decode_vit: false - vit_l2_ratio: 0.8 + vit_l2_ratio: 0.5 + vit_img_size: 224 encode_vit: false predict_discount: false layer_norm: ${..layer_norm} diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index edc7788..21b9b6c 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,16 +1,22 @@ defaults: - - agent: dreamer_v2_atari - - env: atari_freeway - - training: atari - - logger: tensorboard + - agent: dreamer_v2_crafter + - env: crafter + - training: crafter + - logger: wandb - _self_ - override hydra/launcher: joblib seed: 42 device_type: cuda +agent: + world_model: + decode_vit: true + vit_img_size: 224 + vit_l2_ratio: 0.5 + logger: - message: Atari with default dreamer + message: Crafter decode vit log_grads: false training: @@ -38,4 +44,7 @@ hydra: n_jobs: 8 sweeper: params: - env.task_name: amidar,asterix,chopper_command,demon_attack,freeway,private_eye,venture,video_pinball + agent.world_model.vit_img_size: 64,224 + agent.world_model.kl_loss_scale: 0.1,1 + agent.world_model.vit_l2_ratio: 0.1,0.9 + diff --git a/rl_sandbox/utils/env.py b/rl_sandbox/utils/env.py index 101a708..f4ae78b 100644 --- a/rl_sandbox/utils/env.py +++ b/rl_sandbox/utils/env.py @@ -137,7 +137,7 @@ def __init__(self, task_name: str, obs_res: tuple[int, int], sticky: bool, life_ import gym.envs.atari super().__init__(True, obs_res, repeat_action_num, transforms) - self.env: gym.Env = gym.envs.atari.AtariEnv(game=task_name, obs_type='image', frameskip=1, repeat_action_probability=0.25 if sticky else 0, full_action_space=False) + self.env: gym.Env = gym.envs.atari.AtariEnv(game=task_name, obs_type='rgb', frameskip=1, repeat_action_probability=0.25 if sticky else 0, full_action_space=False) # Tell wrapper that the inner env has no action repeat. 
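# gym's AtariPreprocessing refuses to apply its own frame skipping unless the wrapped env id
# looks like a NoFrameskip variant, so a dummy spec is presumably set to satisfy that check.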
self.env.spec = gym.envs.registration.EnvSpec('NoFrameskip-v0') self.env = gym.wrappers.AtariPreprocessing(self.env, From 54a31e687653b663655cfa6a0094b82f3854bc44 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Thu, 13 Jul 2023 13:03:27 +0100 Subject: [PATCH 080/106] Redo dino decoder --- rl_sandbox/agents/dreamer/vision.py | 21 +++++++++++++------ rl_sandbox/agents/dreamer/world_model.py | 7 ++++--- .../agents/dreamer/world_model_slots.py | 13 ++++++------ .../dreamer/world_model_slots_attention.py | 16 +++++++------- .../dreamer/world_model_slots_combined.py | 16 +++++++------- rl_sandbox/agents/dreamer_v2.py | 3 --- rl_sandbox/utils/replay_buffer.py | 7 ++----- 7 files changed, 44 insertions(+), 39 deletions(-) diff --git a/rl_sandbox/agents/dreamer/vision.py b/rl_sandbox/agents/dreamer/vision.py index 00a5d76..cd11d73 100644 --- a/rl_sandbox/agents/dreamer/vision.py +++ b/rl_sandbox/agents/dreamer/vision.py @@ -41,39 +41,48 @@ def __init__(self, kernel_sizes=[5, 5, 6, 6], channel_step = 48, output_channels=3, + conv_kernel_sizes=[], return_dist=True): super().__init__() layers = [] self.channel_step = channel_step - # 2**(len(kernel_sizes)-1)*channel_step - self.convin = nn.Linear(input_size, 32 * self.channel_step) + self.in_channels = 2 **(len(kernel_sizes)-1) * self.channel_step + in_channels = self.in_channels + self.convin = nn.Linear(input_size, in_channels) self.return_dist = return_dist - in_channels = 32 * self.channel_step #2**(len(kernel_sizes) - 1) * self.channel_step for i, k in enumerate(kernel_sizes): out_channels = 2**(len(kernel_sizes) - i - 2) * self.channel_step if i == len(kernel_sizes) - 1: - out_channels = 3 + out_channels = output_channels layers.append(nn.ConvTranspose2d(in_channels, output_channels, kernel_size=k, stride=2, output_padding=0)) else: - layers.append(norm_layer(1, in_channels)) layers.append( nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=2, output_padding=0)) + layers.append(norm_layer(1, out_channels)) layers.append(nn.ELU(inplace=True)) + for k in conv_kernel_sizes: + layers.append( + nn.Conv2d(out_channels, + out_channels, + kernel_size=k, + padding='same')) + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) in_channels = out_channels self.net = nn.Sequential(*layers) def forward(self, X): x = self.convin(X) - x = x.view(-1, 32 * self.channel_step, 1, 1) + x = x.view(-1, self.in_channels, 1, 1) if self.return_dist: return td.Independent(td.Normal(self.net(x), 1.0), 3) else: diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index e2396f6..2f0307a 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -88,7 +88,8 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(self.state_size, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - channel_step=192, + conv_kernel_sizes=[3], + channel_step=2*self.vit_feat_dim, kernel_sizes=self.decoder_kernels, output_channels=self.vit_feat_dim, return_dist=True) @@ -214,14 +215,14 @@ def KL(dist1, dist2, free_nat = True): x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) img_rec = -x_r.log_prob(obs).float().mean() else: - img_rec = 0 + img_rec = torch.tensor(0, device=obs.device) x_r_detached = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1).detach()) losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() 
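# The rescaling of d_rec below presumably puts the DINO feature loss on the same per-element scale
# as the image loss, so vit_l2_ratio mixes comparable magnitudes: the summed feature log-prob is
# divided by the feature volume and multiplied by the image volume. For 64x64x3 observations and
# 8x8x384 DINO features the factor is (64*64*3) / (8*8*384) = 0.5; for the 224 px / 14x14 setting
# it is roughly 0.163.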
d_pred = self.dino_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) d_rec = -d_pred.log_prob(d_obs).float().mean() - d_rec = d_rec / torch.prod(torch.Tensor(d_obs.shape[-3:])) * torch.prod(torch.Tensor(obs.shape[-3:])) + d_rec = d_rec / torch.prod(torch.tensor(d_obs.shape[-3:])) * torch.prod(torch.tensor(obs.shape[-3:])) losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) metrics['loss_l2_rec'] = img_rec diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index 38eb654..0081327 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -106,11 +106,12 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, - norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - channel_step=192, - kernel_sizes=self.decoder_kernels, - output_channels=self.vit_feat_dim+1, - return_dist=False) + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + conv_kernel_sizes=[3], + channel_step=2*self.vit_feat_dim, + kernel_sizes=self.decoder_kernels, + output_channels=self.vit_feat_dim+1, + return_dist=False) self.image_predictor = Decoder( rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, @@ -313,7 +314,7 @@ def KL(dist1, dist2): d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) d_rec = -d_pred.log_prob(d_obs).float().mean() - d_rec = d_rec / torch.prod(torch.Tensor(d_obs.shape[-3:])) * torch.prod(torch.Tensor(obs.shape[-3:])) + d_rec = d_rec / torch.prod(torch.tensor(d_obs.shape[-3:])) * torch.prod(torch.tensor(obs.shape[-3:])) losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) metrics['loss_l2_rec'] = img_rec diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index 5bdfbde..9d43cf8 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -115,12 +115,12 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, - norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - channel_step=192, - # kernel_sizes=[5, 5, 4], # for size 224x224 - kernel_sizes=[3, 4], - output_channels=self.vit_feat_dim+1, - return_dist=False) + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + conv_kernel_sizes=[3], + channel_step=2*self.vit_feat_dim, + kernel_sizes=self.decoder_kernels, + output_channels=self.vit_feat_dim+1, + return_dist=False) self.image_predictor = Decoder( rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, @@ -349,7 +349,7 @@ def KL(dist1, dist2): losses['loss_reconstruction_img'] = img_rec_detached - decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, 8, 8).split([self.vit_feat_dim, 1], dim=2) + decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, self.vit_size, 
self.vit_size).split([self.vit_feat_dim, 1], dim=2) feat_mask = self.slot_mask(masks) d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) @@ -365,7 +365,7 @@ def KL(dist1, dist2): d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) d_rec = -d_pred.log_prob(d_obs).float().mean() - d_rec = d_rec / torch.prod(torch.Tensor(d_obs.shape[-3:])) * torch.prod(torch.Tensor(obs.shape[-3:])) + d_rec = d_rec / torch.prod(torch.tensor(d_obs.shape[-3:])) * torch.prod(torch.tensor(obs.shape[-3:])) losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) metrics['loss_l2_rec'] = img_rec metrics['loss_dino_rec'] = d_rec diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index 7d34f41..4ea5ab0 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -110,12 +110,12 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, - norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - channel_step=192, - # kernel_sizes=[5, 5, 4], # for size 224x224 - kernel_sizes=[3, 4], - output_channels=self.vit_feat_dim+1, - return_dist=False) + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + conv_kernel_sizes=[3], + channel_step=2*self.vit_feat_dim, + kernel_sizes=self.decoder_kernels, + output_channels=self.vit_feat_dim+1, + return_dist=False) self.image_predictor = Decoder( rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, @@ -336,7 +336,7 @@ def KL(dist1, dist2): losses['loss_reconstruction_img'] = img_rec_detached - decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, 8, 8).split([self.vit_feat_dim, 1], dim=2) + decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, self.vit_size, self.vit_size).split([self.vit_feat_dim, 1], dim=2) feat_mask = self.slot_mask(masks) d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) @@ -352,7 +352,7 @@ def KL(dist1, dist2): d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) d_rec = -d_pred.log_prob(d_obs).float().mean() - d_rec = d_rec / torch.prod(torch.Tensor(d_obs.shape[-3:])) * torch.prod(torch.Tensor(obs.shape[-3:])) + d_rec = d_rec / torch.prod(torch.tensor(d_obs.shape[-3:])) * torch.prod(torch.tensor(obs.shape[-3:])) losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) metrics['loss_l2_rec'] = img_rec diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index e051c27..c63331c 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -160,9 +160,6 @@ def from_np(self, arr: np.ndarray): def train(self, rollout_chunks: RolloutChunks): obs, a, r, is_finished, is_first, additional = unpack(rollout_chunks) - if torch.cuda.is_available(): - torch.cuda.current_stream().synchronize() - # obs = self.preprocess_obs(self.from_np(obs)) if self.is_discrete: a = F.one_hot(a.to(torch.int64), num_classes=self.actions_num).squeeze() discount_factors = (1 - is_finished).float() diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index 0b8d541..e452d80 100644 --- 
a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -39,7 +39,7 @@ class Rollout: def __len__(self): return len(self.obs) - def to(self, device: str, non_blocking: bool = True): + def to(self, device: str, non_blocking: bool = False): self.obs = self.obs.to(device, non_blocking=True) self.actions = self.actions.to(device, non_blocking=True) self.rewards = self.rewards.to(device, non_blocking=True) @@ -158,11 +158,8 @@ def sample( is_finished=torch.cat(t), is_first=torch.cat(is_first), additional_data={k: torch.cat(v) for k,v in additional.items()} - ).to(self.device, non_blocking=True) + ).to(self.device, non_blocking=False) # TODO: -# [X] Rewrite to use only torch containers -# [X] Add preprocessing step on adding to replay buffer -# [X] Add possibility to store additional auxilary data (dino encodings) # [ ] (Optional) Utilize torch's dataloader for async sampling From 9d169ba3d1a636496152123d291d71060bf27b0d Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Sat, 15 Jul 2023 14:02:29 +0100 Subject: [PATCH 081/106] Changed prev_slots to use same random vector --- .../agents/dreamer/world_model_slots.py | 10 ++++---- .../dreamer/world_model_slots_attention.py | 4 +-- .../dreamer/world_model_slots_combined.py | 6 ++--- rl_sandbox/agents/dreamer_v2.py | 2 +- .../agent/dreamer_v2_crafter_slotted.yaml | 2 +- rl_sandbox/config/config.yaml | 22 ++++++++-------- rl_sandbox/config/config_attention.yaml | 12 ++++++++- rl_sandbox/config/config_combined.yaml | 16 ++++++------ rl_sandbox/config/training/atari.yaml | 2 +- rl_sandbox/metrics.py | 6 +++-- rl_sandbox/vision/slot_attention.py | 25 ++++++++++++++----- 11 files changed, 68 insertions(+), 39 deletions(-) diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index 0081327..e2f643e 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -45,7 +45,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.vit_l2_ratio = vit_l2_ratio self.vit_img_size = vit_img_size - self.n_dim = 384 + self.n_dim = 192 self.recurrent_model = RSSM( latent_dim, @@ -90,11 +90,11 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4, 4], - channel_step=96, + channel_step=48, double_conv=True, flatten_output=False) - self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num) + self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) if self.encode_vit: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) else: @@ -265,7 +265,7 @@ def KL(dist1, dist2): slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) # FIXME: prev_slots was not used properly, need to rerun test if self.use_prev_slots: - prev_slots = slots_t + prev_slots = self.slot_attention.prev_slots else: prev_slots = None @@ -308,7 +308,7 @@ def KL(dist1, dist2): x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) losses['loss_reconstruction_img'] = -x_r_detached.log_prob(obs).float().mean() - decoded_feats, masks = d_features.reshape(b, -1, self.vit_feat_dim+1, self.vit_size, self.vit_size).split([self.vit_feat_dim, 1], dim=2) + decoded_feats, masks = self.dino_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, self.vit_feat_dim+1, self.vit_size, 
self.vit_size).split([self.vit_feat_dim, 1], dim=2) feat_mask = self.slot_mask(masks) decoded_feats = decoded_feats * feat_mask d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index 9d43cf8..779c52e 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -102,7 +102,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, double_conv=True, flatten_output=False) - self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num) + self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) if self.encode_vit: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) else: @@ -277,7 +277,7 @@ def KL(dist1, dist2): slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) # FIXME: prev_slots was not used properly, need to rerun test if self.use_prev_slots: - prev_slots = slots_t + prev_slots = self.slot_attention.prev_slots else: prev_slots = None diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index 4ea5ab0..ed01742 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -48,7 +48,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.vit_img_size = vit_img_size self.per_slot_rec_loss = per_slot_rec_loss - self.n_dim = 384 + self.n_dim = 192 self.recurrent_model = RSSM( latent_dim, @@ -98,7 +98,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, double_conv=True, flatten_output=False) - self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num) + self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) if self.encode_vit: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) else: @@ -269,7 +269,7 @@ def KL(dist1, dist2): slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) # FIXME: prev_slots was not used properly, need to rerun test if self.use_prev_slots: - prev_slots = slots_t + prev_slots = self.slot_attention.prev_slots else: prev_slots = None diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index c63331c..a1f051e 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -162,7 +162,7 @@ def train(self, rollout_chunks: RolloutChunks): obs, a, r, is_finished, is_first, additional = unpack(rollout_chunks) if self.is_discrete: a = F.one_hot(a.to(torch.int64), num_classes=self.actions_num).squeeze() - discount_factors = (1 - is_finished).float() + discount_factors = self.critic.gamma*(1 - is_finished).float() first_flags = is_first.float() # take some latent embeddings as initial diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml index 2ca3961..a34d9cb 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml @@ -9,6 +9,6 @@ world_model: slots_iter_num: 2 kl_loss_scale: 1e2 decode_vit: true - use_prev_slots: false + use_prev_slots: true vit_l2_ratio: 0.1 encode_vit: false diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 21b9b6c..44601da 100644 --- 
a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -14,9 +14,10 @@ agent: decode_vit: true vit_img_size: 224 vit_l2_ratio: 0.5 + kl_loss_scale: 1.0 logger: - message: Crafter decode vit + message: New decoder log_grads: false training: @@ -38,13 +39,14 @@ debug: profiler: false hydra: - mode: MULTIRUN - #mode: RUN - launcher: - n_jobs: 8 - sweeper: - params: - agent.world_model.vit_img_size: 64,224 - agent.world_model.kl_loss_scale: 0.1,1 - agent.world_model.vit_l2_ratio: 0.1,0.9 + #mode: MULTIRUN + mode: RUN + #launcher: + # n_jobs: 1 + #sweeper: + #params: + # agent.world_model.vit_img_size: 224 + # agent.world_model.kl_loss_scale: 1 + # agent.world_model.kl_free_nats: 0.0,1.0 + # agent.world_model.vit_l2_ratio: 0.5 diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index 7cd7b5b..ee12092 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -9,8 +9,18 @@ defaults: seed: 42 device_type: cuda +agent: + world_model: + encode_vit: false + decode_vit: false + #vit_img_size: 224 + #vit_l2_ratio: 0.5 + slots_iter_num: 3 + slots_num: 6 + kl_loss_scale: 100.0 + logger: - message: Crafter attention 2x sched, fixed layer, dis f16 + message: Attention, without dino, kl=100, 3 iter, 192 n_dim log_grads: false training: diff --git a/rl_sandbox/config/config_combined.yaml b/rl_sandbox/config/config_combined.yaml index d242a40..73fa2b7 100644 --- a/rl_sandbox/config/config_combined.yaml +++ b/rl_sandbox/config/config_combined.yaml @@ -1,7 +1,7 @@ defaults: - agent: dreamer_v2_slotted_combined - env: crafter - - logger: tensorboard + - logger: wandb - training: crafter - _self_ - override hydra/launcher: joblib @@ -11,18 +11,20 @@ device_type: cuda agent: world_model: - #encode_vit: true + encode_vit: false decode_vit: false - #vit_l2_ratio: 1.0 - kl_loss_scale: 1e2 + #vit_img_size: 224 + #vit_l2_ratio: 0.5 + slots_iter_num: 3 + slots_num: 6 + kl_loss_scale: 100.0 logger: - message: Combined, without dino + message: Combined, without dino, kl=100, 3 iter, 192 n_dim log_grads: false training: checkpoint_path: null - steps: 1e6 val_logs_every: 2e4 validation: @@ -45,8 +47,8 @@ hydra: n_jobs: 1 #sweeper: # params: + # agent.world_model.kl_loss_scale: 1e2,1e1 # agent.world_model.slots_num: 3,6 # agent.world_model.per_slot_rec_loss: true # agent.world_model.mask_combination: soft,hard - # agent.world_model.kl_loss_scale: 1e2 # agent.world_model.vit_l2_ratio: 0.1,1e-3 diff --git a/rl_sandbox/config/training/atari.yaml b/rl_sandbox/config/training/atari.yaml index 272db19..1aa0a6f 100644 --- a/rl_sandbox/config/training/atari.yaml +++ b/rl_sandbox/config/training/atari.yaml @@ -4,6 +4,6 @@ batch_size: 16 f16_precision: false pretrain: 1 prioritize_ends: true -train_every: 4 +train_every: 16 save_checkpoint_every: 5e5 val_logs_every: 2e4 diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index b17aced..3c38009 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -270,6 +270,8 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ vit_slots_video = [] rews = [] + vit_size = self.agent.world_model.vit_size + state = None prev_slots = None for idx, (o, a) in enumerate(list(zip(obs, actions))): @@ -284,7 +286,7 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ decoded_imgs = decoded_imgs * img_mask video_r = torch.sum(decoded_imgs, dim=1) - _, vit_masks = 
self.agent.world_model.dino_predictor(state.combined_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, 8, 8).split([self.agent.world_model.vit_feat_dim, 1], dim=2) + _, vit_masks = self.agent.world_model.dino_predictor(state.combined_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, vit_size, vit_size).split([self.agent.world_model.vit_feat_dim, 1], dim=2) vit_mask = F.softmax(vit_masks, dim=1) upscale = tv.transforms.Resize(64, antialias=True) @@ -308,7 +310,7 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ decoded_imgs = decoded_imgs * img_mask video_r = torch.sum(decoded_imgs, dim=1) - _, vit_masks = self.agent.world_model.dino_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, 8, 8).split([self.agent.world_model.vit_feat_dim, 1], dim=2) + _, vit_masks = self.agent.world_model.dino_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, vit_size, vit_size).split([self.agent.world_model.vit_feat_dim, 1], dim=2) vit_mask = F.softmax(vit_masks, dim=1) upscale = tv.transforms.Resize(64, antialias=True) upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) diff --git a/rl_sandbox/vision/slot_attention.py b/rl_sandbox/vision/slot_attention.py index ebde600..12e1e2a 100644 --- a/rl_sandbox/vision/slot_attention.py +++ b/rl_sandbox/vision/slot_attention.py @@ -11,7 +11,7 @@ from rl_sandbox.utils.logger import Logger class SlotAttention(nn.Module): - def __init__(self, num_slots: int, n_dim: int, n_iter: int): + def __init__(self, num_slots: int, n_dim: int, n_iter: int, use_prev_slots: bool): super().__init__() self.n_slots = num_slots @@ -20,15 +20,20 @@ def __init__(self, num_slots: int, n_dim: int, n_iter: int): self.scale = self.n_dim**(-1/2) self.epsilon = 1e-8 - self.slots_mu = nn.Parameter(torch.randn(1, num_slots, self.n_dim)) - self.slots_logsigma = nn.Parameter(torch.zeros(1, num_slots, self.n_dim)) + self.use_prev_slots = use_prev_slots + if use_prev_slots: + self.slots_mu = nn.Parameter(torch.randn(1, 1, self.n_dim)) + self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.n_dim)) + else: + self.slots_mu = nn.Parameter(torch.randn(1, num_slots, self.n_dim)) + self.slots_logsigma = nn.Parameter(torch.zeros(1, num_slots, self.n_dim)) nn.init.xavier_uniform_(self.slots_logsigma) self.slots_proj = nn.Linear(n_dim, n_dim) self.slots_proj_2 = nn.Sequential( - nn.Linear(n_dim, n_dim*4), + nn.Linear(n_dim, n_dim*2), nn.ReLU(inplace=True), - nn.Linear(n_dim*4, n_dim), + nn.Linear(n_dim*2, n_dim), ) self.slots_norm = nn.LayerNorm(self.n_dim) self.slots_norm_2 = nn.LayerNorm(self.n_dim) @@ -36,16 +41,22 @@ def __init__(self, num_slots: int, n_dim: int, n_iter: int): self.inputs_proj = nn.Linear(n_dim, n_dim*2) self.inputs_norm = nn.LayerNorm(self.n_dim) + self.prev_slots = None def forward(self, X: Float[torch.Tensor, 'batch seq n_dim'], prev_slots: t.Optional[Float[torch.Tensor, 'batch num_slots n_dim']]) -> Float[torch.Tensor, 'batch num_slots n_dim']: batch, _, _ = X.shape k, v = self.inputs_proj(self.inputs_norm(X)).chunk(2, dim=-1) if prev_slots is None: - slots = self.slots_mu + self.slots_logsigma.exp() * torch.randn(batch, self.n_slots, self.n_dim, device=X.device) + mu = self.slots_mu.expand(batch, self.n_slots, -1) + sigma = 
self.slots_logsigma.exp().expand(batch, self.n_slots, -1) + slots = mu + sigma * torch.randn(mu.shape, device=X.device) + self.prev_slots = slots.clone() else: slots = prev_slots + self.last_attention = None + for _ in range(self.n_iter): slots_prev = slots slots = self.slots_norm(slots) @@ -54,6 +65,8 @@ def forward(self, X: Float[torch.Tensor, 'batch seq n_dim'], prev_slots: t.Optio attn = F.softmax(self.scale*torch.einsum('bik,bjk->bij', q, k), dim=1) + self.epsilon attn = attn / attn.sum(dim=-1, keepdim=True) + self.last_attention = attn + updates = torch.einsum('bij,bjk->bik', attn, v) / self.n_slots slots = self.slots_reccur(updates.reshape(-1, self.n_dim), slots_prev.reshape(-1, self.n_dim)).reshape(batch, self.n_slots, self.n_dim) slots = slots + self.slots_proj_2(self.slots_norm_2(slots)) From c3ececaec9e0f4fe936d5756356be4854859267e Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Sat, 15 Jul 2023 15:27:31 +0100 Subject: [PATCH 082/106] Rewritten slot attention to calculate out of loop --- .../dreamer/world_model_slots_attention.py | 69 +++++++++---------- .../dreamer/world_model_slots_combined.py | 29 +++----- .../agent/dreamer_v2_slotted_attention.yaml | 2 +- .../agent/dreamer_v2_slotted_combined.yaml | 4 +- rl_sandbox/config/config_attention.yaml | 4 +- rl_sandbox/train.py | 2 +- rl_sandbox/vision/slot_attention.py | 10 ++- 7 files changed, 55 insertions(+), 65 deletions(-) diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index 779c52e..2e43a18 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -50,7 +50,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.vit_img_size = vit_img_size self.per_slot_rec_loss = per_slot_rec_loss - self.n_dim = 384 + self.n_dim = 192 self.recurrent_model = RSSM( latent_dim, @@ -98,7 +98,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4, 4], - channel_step=96, + channel_step=48, double_conv=True, flatten_output=False) @@ -107,7 +107,6 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) else: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) - # self.positional_augmenter_dec = PositionalEmbedding(self.n_dim, (8, 8)) self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), nn.ReLU(inplace=True), @@ -164,7 +163,7 @@ def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: (0.229, 0.224, 0.225)), tv.transforms.Resize(self.vit_img_size, antialias=True)]) obs = ToTensor(obs + 0.5) - d_features = self.dino_vit(obs) + d_features = self.dino_vit(obs).squeeze() return {'d_features': d_features} def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): @@ -223,7 +222,6 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, self.recurrent_model.on_train_step() b, _, h, w = obs.shape # s <- BxHxWx3 - embed = self.encoder(obs) if self.encode_vit: embed = self.post_vit(additional['d_features']) else: @@ -264,23 +262,19 @@ def KL(dist1, dist2): self.last_attn = torch.zeros((self.slots_num, self.slots_num), device=a_c.device) + if self.use_prev_slots: + prev_slots = self.slot_attention.generate_initial(b // self.cluster_size).repeat(self.cluster_size, 1, 
1, 1).transpose(0, 1) + slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1), prev_slots.flatten(0, 1)).reshape(b // self.cluster_size, self.cluster_size, self.slots_num, -1) + else: + slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1)).reshape(b, seq_num, self.slots_num, -1) + for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 - pre_slot_feature_t, a_t, first_t = pre_slot_features_c[:, - t], a_c[:, t].unsqueeze( - 0 - ), first_c[:, - t].unsqueeze( - 0) + slots_t, a_t, first_t = (slots_c[:,t], + a_c[:, t].unsqueeze(0), + first_c[:,t].unsqueeze(0)) a_t = a_t * (1 - first_t) - slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) - # FIXME: prev_slots was not used properly, need to rerun test - if self.use_prev_slots: - prev_slots = self.slot_attention.prev_slots - else: - prev_slots = None - prior, posterior, diff = self.recurrent_model.forward( prev_state, slots_t.unsqueeze(0), a_t) prev_state = posterior @@ -304,15 +298,14 @@ def KL(dist1, dist2): if not self.decode_vit: decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) img_mask = self.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask if self.per_slot_rec_loss: - l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) - normalizing_factor = (torch.prod(torch.tensor(obs.shape[1:]))) / img_mask.sum(dim=[2, 3, 4]) + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) # magic constant that describes the difference between log_prob and mse losses - img_rec = (l2_loss * normalizing_factor).sum(dim=1).mean() * self.slots_num * 8 - decoded_imgs = decoded_imgs * img_mask + img_rec = l2_loss.mean() * normalizing_factor * 8 else: - decoded_imgs = decoded_imgs * img_mask x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean() @@ -321,29 +314,28 @@ def KL(dist1, dist2): if self.vit_l2_ratio != 1.0: decoded_imgs, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2)).reshape(b, -1, 4, h, w).split([3, 1], dim=2) img_mask = self.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask if self.per_slot_rec_loss: - l2_loss = (img_mask*((decoded_imgs - obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) - normalizing_factor = (torch.prod(torch.tensor(obs.shape[1:])))/img_mask.sum(dim=[2, 3, 4]) + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) # magic constant that describes the difference between log_prob and mse losses - img_rec = (l2_loss * normalizing_factor).sum(dim=1).mean() * self.slots_num * 8 - decoded_imgs = decoded_imgs * img_mask + img_rec = l2_loss.mean() * normalizing_factor * 8 else: - decoded_imgs = decoded_imgs * img_mask x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean() else: - img_rec = 0 + img_rec = torch.tensor(0, device=obs.device) decoded_imgs_detached, masks = self.image_predictor(posterior.combined_slots.transpose(0, 1).flatten(0, 2).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=2) img_mask = self.slot_mask(masks) + decoded_imgs_detached = decoded_imgs_detached * img_mask if self.per_slot_rec_loss: - l2_loss = (img_mask*((decoded_imgs_detached - 
obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) - normalizing_factor = (torch.prod(torch.tensor(obs.shape[1:])))/img_mask.sum(dim=[2, 3, 4]) + l2_loss = (img_mask * ((decoded_imgs_detached - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) # magic constant that describes the difference between log_prob and mse losses - img_rec_detached = (l2_loss * normalizing_factor).sum(dim=1).mean() * self.slots_num * 8 + img_rec_detached = l2_loss.mean() * normalizing_factor * 8 else: - decoded_imgs_detached = decoded_imgs_detached * img_mask x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=1), 1.0), 3) img_rec_detached = -x_r_detached.log_prob(obs).float().mean() @@ -354,18 +346,19 @@ def KL(dist1, dist2): d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) + decoded_feats = decoded_feats * feat_mask if self.per_slot_rec_loss: - l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).mean(dim=[2, 3, 4]) - normalizing_factor = (torch.prod(torch.tensor(d_obs.shape[1:])))/feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) + l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(d_obs.shape)[-3:]) / feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=2).max(dim=2).values.max(dim=2).values * (64*64*3) # # magic constant that describes the difference between log_prob and mse losses - d_rec = (l2_loss * normalizing_factor).sum(dim=1).mean()*self.slots_num * 4 - decoded_feats = decoded_feats * feat_mask + d_rec = l2_loss.mean() * normalizing_factor * 4 else: - decoded_feats = decoded_feats * feat_mask d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=1), 1.0), 3) d_rec = -d_pred.log_prob(d_obs).float().mean() d_rec = d_rec / torch.prod(torch.tensor(d_obs.shape[-3:])) * torch.prod(torch.tensor(obs.shape[-3:])) + losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) metrics['loss_l2_rec'] = img_rec metrics['loss_dino_rec'] = d_rec diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index ed01742..a5d58f5 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -1,6 +1,5 @@ import typing as t -import math import torch import torch.distributions as td import torchvision as tv @@ -9,7 +8,7 @@ from rl_sandbox.agents.dreamer import Dist, Normalizer, View from rl_sandbox.agents.dreamer.rssm_slots_combined import RSSM, State -from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, ViTDecoder +from rl_sandbox.agents.dreamer.vision import Decoder, Encoder from rl_sandbox.utils.dists import DistLayer from rl_sandbox.utils.fc_nn import fc_nn_generator from rl_sandbox.vision.dino import ViTFeat @@ -94,7 +93,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4, 4], - channel_step=96, + channel_step=48, double_conv=True, flatten_output=False) @@ -223,7 +222,6 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, else: embed = self.encoder(obs) embed_with_pos_enc = self.positional_augmenter_inp(embed) - # embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) 
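The main restructuring in this patch, visible in both the attention and combined world models, is that slot attention is no longer run once per timestep inside the RSSM unroll: the pre-slot features of all frames in a cluster are flattened into a single batch, passed through one slot-attention call (as in the pre_slot_features / slot_attention code that follows), and reshaped back so the loop only indexes precomputed slots. A rough standalone sketch of the bookkeeping, with assumed sizes and a stand-in for SlotAttention:

import torch

B, T, S, N = 8, 10, 6, 192               # sequences, timesteps per cluster, slots, slot dim (assumed)
tokens = torch.randn(B * T, 36, N)        # pre-slot features for every frame (e.g. 6x6 spatial tokens)

def slot_attention(x):                    # stand-in for SlotAttention.forward
    return torch.randn(x.shape[0], S, N)

slots_c = slot_attention(tokens).reshape(B, T, S, N)   # one call covers the whole cluster
for t in range(T):
    slots_t = slots_c[:, t]               # (B, S, N) slots consumed by the RSSM at step t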
pre_slot_features = self.slot_mlp( embed_with_pos_enc.permute(0, 2, 3, 1).reshape(b, -1, self.n_dim)) @@ -256,23 +254,19 @@ def KL(dist1, dist2): d_features = additional['d_features'] prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) + if self.use_prev_slots: + prev_slots = self.slot_attention.generate_initial(b // self.cluster_size).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) + slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1), prev_slots.flatten(0, 1)).reshape(b // self.cluster_size, self.cluster_size, self.slots_num, -1).transpose(0, 1) + else: + slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1)).reshape(b, seq_num, self.slots_num, -1).transpose(0, 1) + for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 - pre_slot_feature_t, a_t, first_t = pre_slot_features_c[:, - t], a_c[:, t].unsqueeze( - 0 - ), first_c[:, - t].unsqueeze( - 0) + slots_t, a_t, first_t = (slots_c[:,t], + a_c[:, t].unsqueeze(0), + first_c[:,t].unsqueeze(0)) a_t = a_t * (1 - first_t) - slots_t = self.slot_attention(pre_slot_feature_t, prev_slots) - # FIXME: prev_slots was not used properly, need to rerun test - if self.use_prev_slots: - prev_slots = self.slot_attention.prev_slots - else: - prev_slots = None - prior, posterior, diff = self.recurrent_model.forward( prev_state, slots_t.unsqueeze(0), a_t) prev_state = posterior @@ -375,4 +369,3 @@ def KL(dist1, dist2): return losses, posterior, metrics - diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml index a0c184c..8451b71 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -11,7 +11,7 @@ world_model: encode_vit: false decode_vit: true mask_combination: soft - use_prev_slots: false + use_prev_slots: true per_slot_rec_loss: false vit_l2_ratio: 0.1 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml index a6c2bef..82bc48c 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml @@ -5,12 +5,12 @@ defaults: world_model: _target_: rl_sandbox.agents.dreamer.world_model_slots_combined.WorldModel rssm_dim: 512 - slots_num: 6 + slots_num: 5 slots_iter_num: 2 kl_loss_scale: 1e2 encode_vit: false decode_vit: true mask_combination: soft - use_prev_slots: false + use_prev_slots: true per_slot_rec_loss: false vit_l2_ratio: 0.1 diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index ee12092..d389777 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -7,7 +7,7 @@ defaults: - override hydra/launcher: joblib seed: 42 -device_type: cuda +device_type: cuda:1 agent: world_model: @@ -20,7 +20,7 @@ agent: kl_loss_scale: 100.0 logger: - message: Attention, without dino, kl=100, 3 iter, 192 n_dim + message: Attention, without dino, prev_slots, optimized inference log_grads: false training: diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 0a9ad7e..596b4f1 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -1,7 +1,7 @@ import random import os os.environ['MUJOCO_GL'] = 'egl' -os.environ["WANDB_MODE"]="offline" +# os.environ["WANDB_MODE"]="offline" import crafter import hydra diff --git a/rl_sandbox/vision/slot_attention.py b/rl_sandbox/vision/slot_attention.py index 12e1e2a..a23fec6 100644 --- 
a/rl_sandbox/vision/slot_attention.py +++ b/rl_sandbox/vision/slot_attention.py @@ -43,14 +43,18 @@ def __init__(self, num_slots: int, n_dim: int, n_iter: int, use_prev_slots: bool self.inputs_norm = nn.LayerNorm(self.n_dim) self.prev_slots = None + def generate_initial(self, batch: int): + mu = self.slots_mu.expand(batch, self.n_slots, -1) + sigma = self.slots_logsigma.exp().expand(batch, self.n_slots, -1) + slots = mu + sigma * torch.randn(mu.shape, device=mu.device) + return slots + def forward(self, X: Float[torch.Tensor, 'batch seq n_dim'], prev_slots: t.Optional[Float[torch.Tensor, 'batch num_slots n_dim']]) -> Float[torch.Tensor, 'batch num_slots n_dim']: batch, _, _ = X.shape k, v = self.inputs_proj(self.inputs_norm(X)).chunk(2, dim=-1) if prev_slots is None: - mu = self.slots_mu.expand(batch, self.n_slots, -1) - sigma = self.slots_logsigma.exp().expand(batch, self.n_slots, -1) - slots = mu + sigma * torch.randn(mu.shape, device=X.device) + slots = self.generate_initial(batch) self.prev_slots = slots.clone() else: slots = prev_slots From 4cfaed3433201bcbc4cf4e803280891f00bf6da2 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Sun, 16 Jul 2023 16:06:26 +0100 Subject: [PATCH 083/106] Fix KL div calculation --- rl_sandbox/agents/dreamer/common.py | 3 ++- rl_sandbox/agents/dreamer/world_model.py | 4 ++-- .../agents/dreamer/world_model_slots.py | 8 ++++---- .../dreamer/world_model_slots_attention.py | 19 ++++++++----------- .../dreamer/world_model_slots_combined.py | 6 +++--- .../agent/dreamer_v2_slotted_attention.yaml | 4 ++-- .../agent/dreamer_v2_slotted_combined.yaml | 4 ++-- rl_sandbox/config/config.yaml | 5 +++-- rl_sandbox/config/config_attention.yaml | 8 +++++--- rl_sandbox/config/config_combined.yaml | 7 ++++--- rl_sandbox/metrics.py | 4 ++-- 11 files changed, 37 insertions(+), 35 deletions(-) diff --git a/rl_sandbox/agents/dreamer/common.py b/rl_sandbox/agents/dreamer/common.py index 84aab92..36ecf4a 100644 --- a/rl_sandbox/agents/dreamer/common.py +++ b/rl_sandbox/agents/dreamer/common.py @@ -1,5 +1,6 @@ import torch from torch import nn +import torch.distributions as td from rl_sandbox.utils.dists import DistLayer @@ -15,7 +16,7 @@ def forward(self, x): def Dist(val): - return DistLayer('onehot')(val) + return td.Independent(DistLayer('onehot')(val), 1) class Normalizer(nn.Module): diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index 2f0307a..ef19d0f 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -127,8 +127,8 @@ def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): device = next(self.parameters()).device return State(torch.zeros(seq_size, batch_size, self.rssm_dim, device=device), - torch.zeros(seq_size, batch_size, self.latent_classes, self.latent_dim, device=device), - torch.zeros(seq_size, batch_size, self.latent_classes * self.latent_dim, device=device)) + torch.zeros(seq_size, batch_size, self.latent_classes, self.latent_dim, device=device), + torch.zeros(seq_size, batch_size, self.latent_classes * self.latent_dim, device=device)) def predict_next(self, prev_state: State, action): prior, _ = self.recurrent_model.predict_next(prev_state, action) diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index e2f643e..99a15c3 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ 
b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -236,11 +236,11 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, def KL(dist1, dist2): KL_ = torch.distributions.kl_divergence - kl_lhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), - td.OneHotCategoricalStraightThrough(logits=dist1)).mean() + kl_lhs = KL_(td.Independent(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), 1), + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist1), 1)).mean() kl_rhs = KL_( - td.OneHotCategoricalStraightThrough(logits=dist2), - td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist2), 1), + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist1.detach()), 1)).mean() kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index 2e43a18..3428cab 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -115,8 +115,8 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - conv_kernel_sizes=[3], - channel_step=2*self.vit_feat_dim, + conv_kernel_sizes=[], + channel_step=self.vit_feat_dim, kernel_sizes=self.decoder_kernels, output_channels=self.vit_feat_dim+1, return_dist=False) @@ -243,11 +243,11 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, def KL(dist1, dist2): KL_ = torch.distributions.kl_divergence - kl_lhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), - td.OneHotCategoricalStraightThrough(logits=dist1)).mean() + kl_lhs = KL_(td.Independent(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), 1), + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist1), 1)).mean() kl_rhs = KL_( - td.OneHotCategoricalStraightThrough(logits=dist2), - td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist2), 1), + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist1.detach()), 1)).mean() kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) @@ -262,11 +262,8 @@ def KL(dist1, dist2): self.last_attn = torch.zeros((self.slots_num, self.slots_num), device=a_c.device) - if self.use_prev_slots: - prev_slots = self.slot_attention.generate_initial(b // self.cluster_size).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) - slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1), prev_slots.flatten(0, 1)).reshape(b // self.cluster_size, self.cluster_size, self.slots_num, -1) - else: - slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1)).reshape(b, seq_num, self.slots_num, -1) + prev_slots = self.slot_attention.generate_initial(b // self.cluster_size).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) + slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1), prev_slots.flatten(0, 1)).reshape(b // self.cluster_size, self.cluster_size, self.slots_num, -1) for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 diff --git 
a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index a5d58f5..b71c1f3 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -110,8 +110,8 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - conv_kernel_sizes=[3], - channel_step=2*self.vit_feat_dim, + conv_kernel_sizes=[], + channel_step=self.vit_feat_dim, kernel_sizes=self.decoder_kernels, output_channels=self.vit_feat_dim+1, return_dist=False) @@ -190,6 +190,7 @@ def predict_next(self, prev_state: State, action): discount_factors = self.discount_predictor(prior.combined).sample() else: discount_factors = torch.ones_like(reward) + return prior, reward, discount_factors def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, torch.Tensor]]) -> t.Tuple[State, torch.Tensor]: @@ -297,7 +298,6 @@ def KL(dist1, dist2): else: x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=1), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean() - losses['loss_reconstruction'] = img_rec else: if self.vit_l2_ratio != 1.0: diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml index 8451b71..8402bf2 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -7,11 +7,11 @@ world_model: rssm_dim: 512 slots_num: 6 slots_iter_num: 2 - kl_loss_scale: 1e2 + kl_loss_scale: 1.0 encode_vit: false decode_vit: true mask_combination: soft - use_prev_slots: true + use_prev_slots: false per_slot_rec_loss: false vit_l2_ratio: 0.1 diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml index 82bc48c..546b68e 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml @@ -7,10 +7,10 @@ world_model: rssm_dim: 512 slots_num: 5 slots_iter_num: 2 - kl_loss_scale: 1e2 + kl_loss_scale: 1.0 encode_vit: false decode_vit: true mask_combination: soft - use_prev_slots: true + use_prev_slots: false per_slot_rec_loss: false vit_l2_ratio: 0.1 diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 44601da..bf7f8b3 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -7,7 +7,7 @@ defaults: - override hydra/launcher: joblib seed: 42 -device_type: cuda +device_type: cuda:1 agent: world_model: @@ -15,9 +15,10 @@ agent: vit_img_size: 224 vit_l2_ratio: 0.5 kl_loss_scale: 1.0 + kl_free_nats: 1.0 logger: - message: New decoder + message: New decoder, fixed KL log_grads: false training: diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index d389777..5ed361b 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -7,7 +7,7 @@ defaults: - override hydra/launcher: joblib seed: 42 -device_type: cuda:1 +device_type: cuda:0 agent: world_model: @@ -17,10 +17,11 @@ agent: #vit_l2_ratio: 0.5 slots_iter_num: 3 slots_num: 6 - kl_loss_scale: 100.0 + kl_loss_scale: 1.0 + kl_free_nats: 1.0 logger: - message: Attention, without dino, prev_slots, optimized inference + message: Attention, without dino, fixed KL, 1 nat, no prev slot 
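The kl_loss_scale drop from 1e2 to 1.0 in the slotted configs above goes together with wrapping the one-hot posteriors in td.Independent(..., 1) inside the KL terms: the divergence is then summed over the categorical latent variables (one joint-latent KL per state) instead of averaged per variable, so the raw KL value grows by roughly that variable count. A small self-contained sketch with assumed latent sizes:

import torch
import torch.distributions as td

logits_p = torch.randn(7, 16, 32, 32)     # (seq, batch, latents, classes); 32x32 is an assumed size
logits_q = torch.randn(7, 16, 32, 32)

per_latent = td.kl_divergence(td.OneHotCategoricalStraightThrough(logits=logits_q),
                              td.OneHotCategoricalStraightThrough(logits=logits_p)).mean()
joint = td.kl_divergence(td.Independent(td.OneHotCategoricalStraightThrough(logits=logits_q), 1),
                         td.Independent(td.OneHotCategoricalStraightThrough(logits=logits_p), 1)).mean()
print((joint / per_latent).item())        # 32.0: the KL is now summed over the 32 latent variables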
log_grads: false training: @@ -35,6 +36,7 @@ validation: - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator log_video: True _partial_: true + #- _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator _partial_: true diff --git a/rl_sandbox/config/config_combined.yaml b/rl_sandbox/config/config_combined.yaml index 73fa2b7..03dbb13 100644 --- a/rl_sandbox/config/config_combined.yaml +++ b/rl_sandbox/config/config_combined.yaml @@ -7,7 +7,7 @@ defaults: - override hydra/launcher: joblib seed: 42 -device_type: cuda +device_type: cuda:0 agent: world_model: @@ -17,10 +17,11 @@ agent: #vit_l2_ratio: 0.5 slots_iter_num: 3 slots_num: 6 - kl_loss_scale: 100.0 + kl_loss_scale: 1.0 + kl_free_nats: 1.0 logger: - message: Combined, without dino, kl=100, 3 iter, 192 n_dim + message: Combined, without dino, fixed KL, 1 nat, no prev slot log_grads: false training: diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 3c38009..682e53e 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -63,7 +63,7 @@ def on_step(self, logger): if self.agent.is_discrete: self._action_probs += self._action_probs - self._latent_probs += self.agent._state.stoch_dist.probs.squeeze().mean(dim=0) + self._latent_probs += self.agent._state.stoch_dist.base_dist.probs.squeeze().mean(dim=0) def on_episode(self, logger): latent_hist = (self._latent_probs / self.stored_steps).detach().cpu().numpy() @@ -152,7 +152,7 @@ def on_step(self, logger): if self.agent.is_discrete: self._action_probs += self._action_probs - self._latent_probs += self.agent._state[0].stoch_dist.probs.squeeze().mean(dim=0) + self._latent_probs += self.agent._state[0].stoch_dist.base_dist.probs.squeeze().mean(dim=0) def on_episode(self, logger): wm = self.agent.world_model From a7ff93bac1456682dfd84705ab7cb5fa320f4c74 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Sun, 16 Jul 2023 15:07:32 +0000 Subject: [PATCH 084/106] Added pos encoding for combined slot dreamer --- rl_sandbox/agents/dreamer/rssm.py | 11 ++++++ rl_sandbox/agents/dreamer/rssm_slots.py | 11 ++++++ .../agents/dreamer/rssm_slots_attention.py | 11 ++++++ .../agents/dreamer/rssm_slots_combined.py | 34 +++++++++++++++---- .../dreamer/world_model_slots_combined.py | 28 +++++++++------ rl_sandbox/agents/dreamer_v2.py | 4 +-- rl_sandbox/config/config_combined.yaml | 2 +- rl_sandbox/vision/slot_attention.py | 12 +++---- 8 files changed, 86 insertions(+), 27 deletions(-) diff --git a/rl_sandbox/agents/dreamer/rssm.py b/rl_sandbox/agents/dreamer/rssm.py index d99816b..c28e046 100644 --- a/rl_sandbox/agents/dreamer/rssm.py +++ b/rl_sandbox/agents/dreamer/rssm.py @@ -15,6 +15,17 @@ class State: stoch_logits: Float[torch.Tensor, 'seq batch latent_classes latent_dim'] stoch_: t.Optional[Bool[torch.Tensor, 'seq batch stoch_dim']] = None + def flatten(self): + return State(self.determ.flatten(0, 1).unsqueeze(0), + self.stoch_logits.flatten(0, 1).unsqueeze(0), + self.stoch_.flatten(0, 1).unsqueeze(0) if self.stoch_ is not None else None) + + + def detach(self): + return State(self.determ.detach(), + self.stoch_logits.detach(), + self.stoch_.detach() if self.stoch_ is not None else None) + @property def combined(self): return torch.concat([self.determ, self.stoch], dim=-1) diff --git a/rl_sandbox/agents/dreamer/rssm_slots.py b/rl_sandbox/agents/dreamer/rssm_slots.py index cbbe403..24ca200 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots.py +++ b/rl_sandbox/agents/dreamer/rssm_slots.py @@ -14,6 +14,17 @@ class 
State: stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None + def flatten(self): + return State(self.determ.flatten(0, 1).unsqueeze(0), + self.stoch_logits.flatten(0, 1).unsqueeze(0), + self.stoch_.flatten(0, 1).unsqueeze(0) if self.stoch_ is not None else None) + + + def detach(self): + return State(self.determ.detach(), + self.stoch_logits.detach(), + self.stoch_.detach() if self.stoch_ is not None else None) + @property def combined(self): return torch.concat([self.determ, self.stoch], dim=-1).flatten(2, 3) diff --git a/rl_sandbox/agents/dreamer/rssm_slots_attention.py b/rl_sandbox/agents/dreamer/rssm_slots_attention.py index 658cd0e..c5b61bd 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_attention.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_attention.py @@ -16,6 +16,17 @@ class State: stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None + def flatten(self): + return State(self.determ.flatten(0, 1).unsqueeze(0), + self.stoch_logits.flatten(0, 1).unsqueeze(0), + self.stoch_.flatten(0, 1).unsqueeze(0) if self.stoch_ is not None else None) + + + def detach(self): + return State(self.determ.detach(), + self.stoch_logits.detach(), + self.stoch_.detach() if self.stoch_ is not None else None) + @property def combined(self): return torch.concat([self.determ, self.stoch], dim=-1).flatten(2, 3) diff --git a/rl_sandbox/agents/dreamer/rssm_slots_combined.py b/rl_sandbox/agents/dreamer/rssm_slots_combined.py index c471c44..8fee5ee 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_combined.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_combined.py @@ -13,14 +13,31 @@ class State: determ: Float[torch.Tensor, 'seq batch num_slots determ'] stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None + pos_enc: t.Optional[Float[torch.Tensor, '1 1 num_slots stoch_dim+determ']] = None + + def flatten(self): + return State(self.determ.flatten(0, 1).unsqueeze(0), + self.stoch_logits.flatten(0, 1).unsqueeze(0), + self.stoch_.flatten(0, 1).unsqueeze(0) if self.stoch_ is not None else None, + self.pos_enc.detach() if self.pos_enc is not None else None) + + def detach(self): + return State(self.determ.detach(), + self.stoch_logits.detach(), + self.stoch_.detach() if self.stoch_ is not None else None, + self.pos_enc.detach() if self.pos_enc is not None else None) @property def combined(self): - return torch.concat([self.determ, self.stoch], dim=-1).flatten(2, 3) + return self.combined_slots.flatten(2, 3) @property def combined_slots(self): - return torch.concat([self.determ, self.stoch], dim=-1) + state = torch.concat([self.determ, self.stoch], dim=-1) + if self.pos_enc is not None: + return state + self.pos_enc + else: + return state @property def stoch(self): @@ -40,7 +57,9 @@ def stack(cls, states: list['State'], dim=0): else: stochs = None return State(torch.cat([state.determ for state in states], dim=dim), - torch.cat([state.stoch_logits for state in states], dim=dim), stochs) + torch.cat([state.stoch_logits for state in states], dim=dim), + stochs, + states[0].pos_enc) class GRUCell(nn.Module): @@ -164,7 +183,7 @@ def estimate_stochastic_latent(self, prev_determ: torch.Tensor): def predict_next(self, prev_state: State, action) -> State: x = self.pre_determ_recurrent( torch.concat([ - 
prev_state.stoch, + prev_state.stoch + prev_state.pos_enc[:, :, :, -prev_state.stoch.shape[-1]:], action.unsqueeze(2).repeat((1, 1, prev_state.determ.shape[2], 1)) ], dim=-1)) @@ -178,16 +197,17 @@ def predict_next(self, prev_state: State, action) -> State: # used for KL divergence # TODO: Test both options (with slot in batch size and in feature dim) - predicted_stoch_logits = self.estimate_stochastic_latent(x.reshape(prev_state.determ.shape)) + predicted_stoch_logits = self.estimate_stochastic_latent(x.reshape(prev_state.determ.shape) + prev_state.pos_enc[:, :, :, :-prev_state.stoch.shape[-1]]) # Size is 1 x B x slots_num x ... return State(determ_post.reshape(prev_state.determ.shape), - predicted_stoch_logits.reshape(prev_state.stoch_logits.shape)), diff + predicted_stoch_logits.reshape(prev_state.stoch_logits.shape), + pos_enc=prev_state.pos_enc), diff def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' return State( prior.determ, self.stoch_net(torch.concat([prior.determ, embed], dim=-1)).flatten( - 1, 2).reshape(prior.stoch_logits.shape)) + 1, 2).reshape(prior.stoch_logits.shape), pos_enc=prior.pos_enc) def forward(self, h_prev: State, embed, action) -> tuple[State, State]: """ diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index b71c1f3..e66f448 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -98,6 +98,8 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, flatten_output=False) self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) + self.state_emb = nn.Embedding(slots_num, self.state_size // slots_num) + self.slot_emb = nn.Embedding(slots_num, self.n_dim) if self.encode_vit: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) else: @@ -136,6 +138,10 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, layer_norm=layer_norm, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + self.slot_indexer = torch.linspace(0, + self.slots_num-1, + self.slots_num, + dtype=torch.long) def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: match self.mask_combination: @@ -180,7 +186,8 @@ def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): batch_size, self.slots_num, self.latent_classes * self.latent_dim, - device=device)), None + device=device), + self.state_emb(self.slot_indexer).unsqueeze(0).unsqueeze(0)), None def predict_next(self, prev_state: State, action): prior, _ = self.recurrent_model.predict_next(prev_state, action) @@ -211,6 +218,8 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), action) + + pos_enc = self.state_emb(torch.linspace(0, self.slots_num-1, self.slots_num, dtype=torch.long)).unsqueeze(0).unsqueeze(0) return posterior, slots_t def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, @@ -239,11 +248,11 @@ def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, def KL(dist1, dist2): KL_ = torch.distributions.kl_divergence - kl_lhs = KL_(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), - td.OneHotCategoricalStraightThrough(logits=dist1)).mean() + kl_lhs = KL_(td.Independent(td.OneHotCategoricalStraightThrough(logits=dist2.detach()), 1), + 
td.Independent(td.OneHotCategoricalStraightThrough(logits=dist1), 1)).mean() kl_rhs = KL_( - td.OneHotCategoricalStraightThrough(logits=dist2), - td.OneHotCategoricalStraightThrough(logits=dist1.detach())).mean() + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist2), 1), + td.Independent(td.OneHotCategoricalStraightThrough(logits=dist1.detach()), 1)).mean() kl_lhs = torch.maximum(kl_lhs, self.kl_free_nats) kl_rhs = torch.maximum(kl_rhs, self.kl_free_nats) return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) @@ -255,11 +264,10 @@ def KL(dist1, dist2): d_features = additional['d_features'] prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) - if self.use_prev_slots: - prev_slots = self.slot_attention.generate_initial(b // self.cluster_size).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) - slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1), prev_slots.flatten(0, 1)).reshape(b // self.cluster_size, self.cluster_size, self.slots_num, -1).transpose(0, 1) - else: - slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1)).reshape(b, seq_num, self.slots_num, -1).transpose(0, 1) + slot_pos_enc = self.slot_emb(self.slot_indexer).unsqueeze(0) + prev_slots = (self.slot_attention.generate_initial(b // self.cluster_size) + slot_pos_enc).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) + slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1), prev_slots.flatten(0, 1)).reshape(b // self.cluster_size, self.cluster_size, self.slots_num, -1) + slots_c = slots_c + slot_pos_enc.unsqueeze(0) for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index a1f051e..7dcc9bd 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -178,9 +178,7 @@ def train(self, rollout_chunks: RolloutChunks): metrics_wm |= self.world_model_optimizer.step(losses_wm['loss_wm']) with torch.cuda.amp.autocast(enabled=self.is_f16): - initial_states = discovered_states.__class__(discovered_states.determ.flatten(0, 1).unsqueeze(0).detach(), - discovered_states.stoch_logits.flatten(0, 1).unsqueeze(0).detach(), - discovered_states.stoch_.flatten(0, 1).unsqueeze(0).detach()) + initial_states = discovered_states.flatten().detach() states, actions, rewards, discount_factors = self.imagine_trajectory(initial_states) zs = states.combined diff --git a/rl_sandbox/config/config_combined.yaml b/rl_sandbox/config/config_combined.yaml index 03dbb13..d355736 100644 --- a/rl_sandbox/config/config_combined.yaml +++ b/rl_sandbox/config/config_combined.yaml @@ -21,7 +21,7 @@ agent: kl_free_nats: 1.0 logger: - message: Combined, without dino, fixed KL, 1 nat, no prev slot + message: Combined, without dino, added pos encoding log_grads: false training: diff --git a/rl_sandbox/vision/slot_attention.py b/rl_sandbox/vision/slot_attention.py index a23fec6..f80cd37 100644 --- a/rl_sandbox/vision/slot_attention.py +++ b/rl_sandbox/vision/slot_attention.py @@ -21,12 +21,12 @@ def __init__(self, num_slots: int, n_dim: int, n_iter: int, use_prev_slots: bool self.epsilon = 1e-8 self.use_prev_slots = use_prev_slots - if use_prev_slots: - self.slots_mu = nn.Parameter(torch.randn(1, 1, self.n_dim)) - self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.n_dim)) - else: - self.slots_mu = nn.Parameter(torch.randn(1, num_slots, self.n_dim)) - self.slots_logsigma = nn.Parameter(torch.zeros(1, num_slots, self.n_dim)) + # if use_prev_slots: + self.slots_mu = nn.Parameter(torch.randn(1, 1, 
self.n_dim)) + self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.n_dim)) + # else: + # self.slots_mu = nn.Parameter(torch.randn(1, num_slots, self.n_dim)) + # self.slots_logsigma = nn.Parameter(torch.zeros(1, num_slots, self.n_dim)) nn.init.xavier_uniform_(self.slots_logsigma) self.slots_proj = nn.Linear(n_dim, n_dim) From e955bc7d40b31e6d4f451e6c959eb1b590fa5ac8 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Sun, 16 Jul 2023 16:35:15 +0100 Subject: [PATCH 085/106] fixup! Added pos encoding for combined slot dreamer --- rl_sandbox/agents/dreamer/world_model_slots_combined.py | 5 ++--- rl_sandbox/config/config_combined.yaml | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index e66f448..0460f10 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -138,10 +138,10 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, layer_norm=layer_norm, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) - self.slot_indexer = torch.linspace(0, + self.register_buffer('slot_indexer', torch.linspace(0, self.slots_num-1, self.slots_num, - dtype=torch.long) + dtype=torch.long)) def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: match self.mask_combination: @@ -219,7 +219,6 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), action) - pos_enc = self.state_emb(torch.linspace(0, self.slots_num-1, self.slots_num, dtype=torch.long)).unsqueeze(0).unsqueeze(0) return posterior, slots_t def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, diff --git a/rl_sandbox/config/config_combined.yaml b/rl_sandbox/config/config_combined.yaml index d355736..e273c3a 100644 --- a/rl_sandbox/config/config_combined.yaml +++ b/rl_sandbox/config/config_combined.yaml @@ -7,7 +7,7 @@ defaults: - override hydra/launcher: joblib seed: 42 -device_type: cuda:0 +device_type: cuda:1 agent: world_model: From f5d6413908faaf5630e22e25e2b9c502064d2447 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Mon, 17 Jul 2023 10:31:25 +0100 Subject: [PATCH 086/106] Changed pos enc to sin-based --- .../agents/dreamer/rssm_slots_combined.py | 4 ++-- .../dreamer/world_model_slots_combined.py | 24 ++++++++++++++----- rl_sandbox/config/config_combined.yaml | 2 +- rl_sandbox/vision/slot_attention.py | 12 +++++----- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/rl_sandbox/agents/dreamer/rssm_slots_combined.py b/rl_sandbox/agents/dreamer/rssm_slots_combined.py index 8fee5ee..5716190 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_combined.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_combined.py @@ -183,7 +183,7 @@ def estimate_stochastic_latent(self, prev_determ: torch.Tensor): def predict_next(self, prev_state: State, action) -> State: x = self.pre_determ_recurrent( torch.concat([ - prev_state.stoch + prev_state.pos_enc[:, :, :, -prev_state.stoch.shape[-1]:], + prev_state.stoch, action.unsqueeze(2).repeat((1, 1, prev_state.determ.shape[2], 1)) ], dim=-1)) @@ -197,7 +197,7 @@ def predict_next(self, prev_state: State, action) -> State: # used for KL divergence # TODO: Test both options (with slot in batch size and in feature dim) - predicted_stoch_logits = 
self.estimate_stochastic_latent(x.reshape(prev_state.determ.shape) + prev_state.pos_enc[:, :, :, :-prev_state.stoch.shape[-1]]) + predicted_stoch_logits = self.estimate_stochastic_latent(x.reshape(prev_state.determ.shape)) # Size is 1 x B x slots_num x ... return State(determ_post.reshape(prev_state.determ.shape), predicted_stoch_logits.reshape(prev_state.stoch_logits.shape), diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index 0460f10..33576bf 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -98,8 +98,18 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, flatten_output=False) self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) - self.state_emb = nn.Embedding(slots_num, self.state_size // slots_num) - self.slot_emb = nn.Embedding(slots_num, self.n_dim) + def getPositionEncoding(seq_len, d, n=10000): + import numpy as np + P = np.zeros((seq_len, d)) + for k in range(seq_len): + for i in np.arange(int(d/2)): + denominator = np.power(n, 2*i/d) + P[k, 2*i] = np.sin(k/denominator) + P[k, 2*i+1] = np.cos(k/denominator) + return P + self.register_buffer('pos_enc', torch.from_numpy(getPositionEncoding(self.slots_num, self.state_size // slots_num)).to(dtype=torch.float32)) + # self.state_emb = nn.Embedding(slots_num, self.state_size // slots_num) + # self.slot_emb = nn.Embedding(slots_num, self.n_dim) if self.encode_vit: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) else: @@ -187,7 +197,8 @@ def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): self.slots_num, self.latent_classes * self.latent_dim, device=device), - self.state_emb(self.slot_indexer).unsqueeze(0).unsqueeze(0)), None + self.pos_enc.unsqueeze(0).unsqueeze(0)), None + # self.state_emb(self.slot_indexer).unsqueeze(0).unsqueeze(0)), None def predict_next(self, prev_state: State, action): prior, _ = self.recurrent_model.predict_next(prev_state, action) @@ -263,10 +274,11 @@ def KL(dist1, dist2): d_features = additional['d_features'] prev_state, prev_slots = self.get_initial_state(b // self.cluster_size) - slot_pos_enc = self.slot_emb(self.slot_indexer).unsqueeze(0) - prev_slots = (self.slot_attention.generate_initial(b // self.cluster_size) + slot_pos_enc).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) + # slot_pos_enc = self.slot_emb(self.slot_indexer).unsqueeze(0) + # prev_slots = (self.slot_attention.generate_initial(b // self.cluster_size) + slot_pos_enc).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) + prev_slots = (self.slot_attention.generate_initial(b // self.cluster_size)).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1), prev_slots.flatten(0, 1)).reshape(b // self.cluster_size, self.cluster_size, self.slots_num, -1) - slots_c = slots_c + slot_pos_enc.unsqueeze(0) + # slots_c = slots_c + slot_pos_enc.unsqueeze(0) for t in range(self.cluster_size): # s_t <- 1xB^xHxWx3 diff --git a/rl_sandbox/config/config_combined.yaml b/rl_sandbox/config/config_combined.yaml index e273c3a..f874214 100644 --- a/rl_sandbox/config/config_combined.yaml +++ b/rl_sandbox/config/config_combined.yaml @@ -21,7 +21,7 @@ agent: kl_free_nats: 1.0 logger: - message: Combined, without dino, added pos encoding + message: Combined, without dino, added pos encoding for reconstruction log_grads: false training: diff --git 
a/rl_sandbox/vision/slot_attention.py b/rl_sandbox/vision/slot_attention.py index f80cd37..a23fec6 100644 --- a/rl_sandbox/vision/slot_attention.py +++ b/rl_sandbox/vision/slot_attention.py @@ -21,12 +21,12 @@ def __init__(self, num_slots: int, n_dim: int, n_iter: int, use_prev_slots: bool self.epsilon = 1e-8 self.use_prev_slots = use_prev_slots - # if use_prev_slots: - self.slots_mu = nn.Parameter(torch.randn(1, 1, self.n_dim)) - self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.n_dim)) - # else: - # self.slots_mu = nn.Parameter(torch.randn(1, num_slots, self.n_dim)) - # self.slots_logsigma = nn.Parameter(torch.zeros(1, num_slots, self.n_dim)) + if use_prev_slots: + self.slots_mu = nn.Parameter(torch.randn(1, 1, self.n_dim)) + self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, self.n_dim)) + else: + self.slots_mu = nn.Parameter(torch.randn(1, num_slots, self.n_dim)) + self.slots_logsigma = nn.Parameter(torch.zeros(1, num_slots, self.n_dim)) nn.init.xavier_uniform_(self.slots_logsigma) self.slots_proj = nn.Linear(n_dim, n_dim) From 9849dcafd1ea910d3432f54fae30703c1821f88c Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Mon, 17 Jul 2023 21:38:30 +0100 Subject: [PATCH 087/106] Fixed attention RSSM, added updated determ for more stable learning --- rl_sandbox/agents/dreamer/common.py | 9 +++++ .../agents/dreamer/rssm_slots_attention.py | 35 ++++++++++++------- .../dreamer/world_model_slots_attention.py | 8 +++-- .../dreamer/world_model_slots_combined.py | 22 ++---------- rl_sandbox/config/config.yaml | 26 +++++--------- rl_sandbox/config/config_attention.yaml | 4 +-- 6 files changed, 49 insertions(+), 55 deletions(-) diff --git a/rl_sandbox/agents/dreamer/common.py b/rl_sandbox/agents/dreamer/common.py index 36ecf4a..1c22064 100644 --- a/rl_sandbox/agents/dreamer/common.py +++ b/rl_sandbox/agents/dreamer/common.py @@ -1,9 +1,18 @@ import torch from torch import nn import torch.distributions as td +import numpy as np from rl_sandbox.utils.dists import DistLayer +def get_position_encoding(seq_len, d, n=10000): + P = np.zeros((seq_len, d)) + for k in range(seq_len): + for i in np.arange(int(d/2)): + denominator = np.power(n, 2*i/d) + P[k, 2*i] = np.sin(k/denominator) + P[k, 2*i+1] = np.cos(k/denominator) + return P class View(nn.Module): diff --git a/rl_sandbox/agents/dreamer/rssm_slots_attention.py b/rl_sandbox/agents/dreamer/rssm_slots_attention.py index c5b61bd..88681ee 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_attention.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_attention.py @@ -15,25 +15,32 @@ class State: determ: Float[torch.Tensor, 'seq batch num_slots determ'] stoch_logits: Float[torch.Tensor, 'seq batch num_slots latent_classes latent_dim'] stoch_: t.Optional[Bool[torch.Tensor, 'seq batch num_slots stoch_dim']] = None + pos_enc: t.Optional[Float[torch.Tensor, '1 1 num_slots stoch_dim+determ']] = None + determ_updated: t.Optional[Float[torch.Tensor, 'seq batch num_slots determ']] = None def flatten(self): return State(self.determ.flatten(0, 1).unsqueeze(0), self.stoch_logits.flatten(0, 1).unsqueeze(0), - self.stoch_.flatten(0, 1).unsqueeze(0) if self.stoch_ is not None else None) - + self.stoch_.flatten(0, 1).unsqueeze(0) if self.stoch_ is not None else None, + self.pos_enc.detach() if self.pos_enc is not None else None) def detach(self): return State(self.determ.detach(), self.stoch_logits.detach(), - self.stoch_.detach() if self.stoch_ is not None else None) + self.stoch_.detach() if self.stoch_ is not None else None, + self.pos_enc.detach() if self.pos_enc 
is not None else None) @property def combined(self): - return torch.concat([self.determ, self.stoch], dim=-1).flatten(2, 3) + return self.combined_slots.flatten(2, 3) @property def combined_slots(self): - return torch.concat([self.determ, self.stoch], dim=-1) + state = torch.concat([self.determ, self.stoch], dim=-1) + if self.pos_enc is not None: + return state + self.pos_enc + else: + return state @property def stoch(self): @@ -53,7 +60,9 @@ def stack(cls, states: list['State'], dim=0): else: stochs = None return State(torch.cat([state.determ for state in states], dim=dim), - torch.cat([state.stoch_logits for state in states], dim=dim), stochs) + torch.cat([state.stoch_logits for state in states], dim=dim), + stochs, + states[0].pos_enc) class RSSM(nn.Module): @@ -169,7 +178,7 @@ def predict_next(self, prev_state: State, action) -> State: if self.discrete_rssm: raise NotImplementedError("discrete rssm was not adopted for slot attention") else: - determ_post, diff = determ_prior, 0 + determ_post, diff = determ_prior.clone(), 0 determ_post = determ_post.reshape(prev_state.determ.shape) @@ -177,7 +186,7 @@ def predict_next(self, prev_state: State, action) -> State: # Experiment, when only stochastic part is affected and deterministic is not touched # We keep flow of gradients through determ block, but updating it with stochastic part for _ in range(self.attention_block_num): - q, k, v = self.hidden_attention_proj(self.pre_norm(determ_post)).chunk(3, dim=-1) # + q, k, v = self.hidden_attention_proj(self.pre_norm(determ_post)).chunk(3, dim=-1) if self.symmetric_qk: k = q qk = torch.einsum('lbih,lbjh->lbij', q, k) @@ -188,7 +197,7 @@ def predict_next(self, prev_state: State, action) -> State: coeff = self.attention_scheduler.val attn = coeff * attn + (1 - coeff) * torch.eye(q.shape[-2],device=q.device) - updates = torch.einsum('lbij,lbjh->lbih', qk, v) + updates = torch.einsum('lbij,lbjh->lbih', attn, v) determ_post = determ_post + self.fc(self.fc_norm(updates)) self.last_attention = attn.mean(dim=1).squeeze() @@ -196,14 +205,14 @@ def predict_next(self, prev_state: State, action) -> State: # used for KL divergence predicted_stoch_logits = self.estimate_stochastic_latent(determ_post.reshape(determ_prior.shape)).reshape(prev_state.stoch_logits.shape) # Size is 1 x B x slots_num x ... 
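# --- editorial aside (not part of the patch series) ---------------------------
# The hunk above anneals the slot-mixing attention toward full cross-slot
# communication: while `coeff` is near 0 each deterministic slot state only
# attends to itself, and mixing opens up as the scheduler ramps `coeff` to 1.
# A minimal, self-contained sketch of that step; the class name `SlotMixer`
# and the explicit `coeff` argument are editorial inventions (the patches read
# the coefficient from `attention_scheduler.val`).
import torch
from torch import nn

class SlotMixer(nn.Module):
    def __init__(self, hidden_size: int, symmetric_qk: bool = False):
        super().__init__()
        self.pre_norm = nn.LayerNorm(hidden_size)
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.fc_norm = nn.LayerNorm(hidden_size)
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.scale = hidden_size ** -0.5
        self.symmetric_qk = symmetric_qk

    def forward(self, determ: torch.Tensor, coeff: float) -> torch.Tensor:
        # determ: (seq, batch, slots, hidden)
        q, k, v = self.qkv(self.pre_norm(determ)).chunk(3, dim=-1)
        if self.symmetric_qk:
            k = q
        attn = torch.softmax(self.scale * torch.einsum('lbih,lbjh->lbij', q, k), dim=-1)
        # Blend with the identity: coeff=0 -> purely per-slot dynamics,
        # coeff=1 -> full attention between slots.
        eye = torch.eye(determ.shape[-2], device=determ.device)
        attn = coeff * attn + (1 - coeff) * eye
        updates = torch.einsum('lbij,lbjh->lbih', attn, v)
        return determ + self.fc(self.fc_norm(updates))

# Usage (shapes follow the RSSM convention seq x batch x slots x hidden):
#   mixer = SlotMixer(hidden_size=200)
#   h = torch.randn(1, 16, 6, 200)
#   h = mixer(h, coeff=0.5)
# -------------------------------------------------------------------------------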
- return State(determ_post, - predicted_stoch_logits), diff + return State(determ_prior.reshape(prev_state.determ.shape), + predicted_stoch_logits, pos_enc=prev_state.pos_enc, determ_updated=determ_post), diff def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' return State( - prior.determ, + prior.determ_updated, self.stoch_net(torch.concat([prior.determ, embed], dim=-1)).flatten( - 1, 2).reshape(prior.stoch_logits.shape)) + 1, 2).reshape(prior.stoch_logits.shape), pos_enc=prior.pos_enc) def forward(self, h_prev: State, embed, action) -> tuple[State, State]: """ diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index 3428cab..c5f7ebd 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -6,7 +6,7 @@ from torch import nn from torch.nn import functional as F -from rl_sandbox.agents.dreamer import Dist, Normalizer, View +from rl_sandbox.agents.dreamer import Dist, Normalizer, View, get_position_encoding from rl_sandbox.agents.dreamer.rssm_slots_attention import RSSM, State from rl_sandbox.agents.dreamer.vision import Decoder, Encoder from rl_sandbox.utils.dists import DistLayer @@ -103,6 +103,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, flatten_output=False) self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) + self.register_buffer('pos_enc', torch.from_numpy(get_position_encoding(self.slots_num, self.state_size // slots_num)).to(dtype=torch.float32)) if self.encode_vit: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) else: @@ -185,7 +186,8 @@ def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): batch_size, self.slots_num, self.latent_classes * self.latent_dim, - device=device)), None + device=device), + self.pos_enc.unsqueeze(0).unsqueeze(0)), None def predict_next(self, prev_state: State, action): prior, _ = self.recurrent_model.predict_next(prev_state, action) @@ -262,7 +264,7 @@ def KL(dist1, dist2): self.last_attn = torch.zeros((self.slots_num, self.slots_num), device=a_c.device) - prev_slots = self.slot_attention.generate_initial(b // self.cluster_size).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) + prev_slots = (self.slot_attention.generate_initial(b // self.cluster_size)).repeat(self.cluster_size, 1, 1, 1).transpose(0, 1) slots_c = self.slot_attention(pre_slot_features_c.flatten(0, 1), prev_slots.flatten(0, 1)).reshape(b // self.cluster_size, self.cluster_size, self.slots_num, -1) for t in range(self.cluster_size): diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index 33576bf..1a0ebe7 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -6,7 +6,7 @@ from torch import nn from torch.nn import functional as F -from rl_sandbox.agents.dreamer import Dist, Normalizer, View +from rl_sandbox.agents.dreamer import Dist, Normalizer, View, get_position_encoding from rl_sandbox.agents.dreamer.rssm_slots_combined import RSSM, State from rl_sandbox.agents.dreamer.vision import Decoder, Encoder from rl_sandbox.utils.dists import DistLayer @@ -98,18 +98,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, flatten_output=False) self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, 
use_prev_slots) - def getPositionEncoding(seq_len, d, n=10000): - import numpy as np - P = np.zeros((seq_len, d)) - for k in range(seq_len): - for i in np.arange(int(d/2)): - denominator = np.power(n, 2*i/d) - P[k, 2*i] = np.sin(k/denominator) - P[k, 2*i+1] = np.cos(k/denominator) - return P - self.register_buffer('pos_enc', torch.from_numpy(getPositionEncoding(self.slots_num, self.state_size // slots_num)).to(dtype=torch.float32)) - # self.state_emb = nn.Embedding(slots_num, self.state_size // slots_num) - # self.slot_emb = nn.Embedding(slots_num, self.n_dim) + self.register_buffer('pos_enc', torch.from_numpy(get_position_encoding(self.slots_num, self.state_size // slots_num)).to(dtype=torch.float32)) if self.encode_vit: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) else: @@ -148,10 +137,6 @@ def getPositionEncoding(seq_len, d, n=10000): layer_norm=layer_norm, final_activation=DistLayer('binary')) self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) - self.register_buffer('slot_indexer', torch.linspace(0, - self.slots_num-1, - self.slots_num, - dtype=torch.long)) def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: match self.mask_combination: @@ -198,7 +183,6 @@ def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): self.latent_classes * self.latent_dim, device=device), self.pos_enc.unsqueeze(0).unsqueeze(0)), None - # self.state_emb(self.slot_indexer).unsqueeze(0).unsqueeze(0)), None def predict_next(self, prev_state: State, action): prior, _ = self.recurrent_model.predict_next(prev_state, action) @@ -208,7 +192,6 @@ def predict_next(self, prev_state: State, action): discount_factors = self.discount_predictor(prior.combined).sample() else: discount_factors = torch.ones_like(reward) - return prior, reward, discount_factors def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, torch.Tensor]]) -> t.Tuple[State, torch.Tensor]: @@ -229,7 +212,6 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t _, posterior, _ = self.recurrent_model.forward(state, slots_t.unsqueeze(0), action) - return posterior, slots_t def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index bf7f8b3..626d46c 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -7,18 +7,10 @@ defaults: - override hydra/launcher: joblib seed: 42 -device_type: cuda:1 - -agent: - world_model: - decode_vit: true - vit_img_size: 224 - vit_l2_ratio: 0.5 - kl_loss_scale: 1.0 - kl_free_nats: 1.0 +device_type: cuda logger: - message: New decoder, fixed KL + message: Crafter default log_grads: false training: @@ -42,12 +34,12 @@ debug: hydra: #mode: MULTIRUN mode: RUN - #launcher: - # n_jobs: 1 + launcher: + n_jobs: 1 #sweeper: - #params: - # agent.world_model.vit_img_size: 224 - # agent.world_model.kl_loss_scale: 1 - # agent.world_model.kl_free_nats: 0.0,1.0 - # agent.world_model.vit_l2_ratio: 0.5 + # params: + # agent.world_model._target_: rl_sandbox.agents.dreamer.world_model_slots_combined.WorldModel,rl_sandbox.agents.dreamer.world_model_slots_attention.WorldModel + # agent.world_model.vit_l2_ratio: 0.1,0.5 + # agent.world_model.kl_loss_scale: 1e1,1e2,1e3,1e4 + # agent.world_model.vit_l2_ratio: 0.1,0.9 diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index 5ed361b..77905e5 100644 --- a/rl_sandbox/config/config_attention.yaml +++ 
b/rl_sandbox/config/config_attention.yaml @@ -7,7 +7,7 @@ defaults: - override hydra/launcher: joblib seed: 42 -device_type: cuda:0 +device_type: cuda:1 agent: world_model: @@ -21,7 +21,7 @@ agent: kl_free_nats: 1.0 logger: - message: Attention, without dino, fixed KL, 1 nat, no prev slot + message: Attention, without dino, Fixed attn, pos enc, determ_updated log_grads: false training: From b87c1940ba15292e0cdd2f49eb6b13ea42d0734e Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Tue, 18 Jul 2023 13:12:33 +0100 Subject: [PATCH 088/106] Fix determ updated --- .../agents/dreamer/rssm_slots_attention.py | 8 +++---- .../agents/dreamer/rssm_slots_combined.py | 2 +- .../agent/dreamer_v2_slotted_attention.yaml | 2 +- rl_sandbox/config/config_attention.yaml | 2 +- rl_sandbox/metrics.py | 21 +++++++++---------- 5 files changed, 17 insertions(+), 18 deletions(-) diff --git a/rl_sandbox/agents/dreamer/rssm_slots_attention.py b/rl_sandbox/agents/dreamer/rssm_slots_attention.py index 88681ee..c412571 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_attention.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_attention.py @@ -22,7 +22,7 @@ def flatten(self): return State(self.determ.flatten(0, 1).unsqueeze(0), self.stoch_logits.flatten(0, 1).unsqueeze(0), self.stoch_.flatten(0, 1).unsqueeze(0) if self.stoch_ is not None else None, - self.pos_enc.detach() if self.pos_enc is not None else None) + self.pos_enc if self.pos_enc is not None else None) def detach(self): return State(self.determ.detach(), @@ -178,7 +178,7 @@ def predict_next(self, prev_state: State, action) -> State: if self.discrete_rssm: raise NotImplementedError("discrete rssm was not adopted for slot attention") else: - determ_post, diff = determ_prior.clone(), 0 + determ_post, diff = determ_prior, 0 determ_post = determ_post.reshape(prev_state.determ.shape) @@ -210,8 +210,8 @@ def predict_next(self, prev_state: State, action) -> State: def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' return State( - prior.determ_updated, - self.stoch_net(torch.concat([prior.determ, embed], dim=-1)).flatten( + prior.determ, + self.stoch_net(torch.concat([prior.determ_updated, embed], dim=-1)).flatten( 1, 2).reshape(prior.stoch_logits.shape), pos_enc=prior.pos_enc) def forward(self, h_prev: State, embed, action) -> tuple[State, State]: diff --git a/rl_sandbox/agents/dreamer/rssm_slots_combined.py b/rl_sandbox/agents/dreamer/rssm_slots_combined.py index 5716190..39cda6c 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_combined.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_combined.py @@ -19,7 +19,7 @@ def flatten(self): return State(self.determ.flatten(0, 1).unsqueeze(0), self.stoch_logits.flatten(0, 1).unsqueeze(0), self.stoch_.flatten(0, 1).unsqueeze(0) if self.stoch_ is not None else None, - self.pos_enc.detach() if self.pos_enc is not None else None) + self.pos_enc if self.pos_enc is not None else None) def detach(self): return State(self.determ.detach(), diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml index 8402bf2..48cf106 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -17,4 +17,4 @@ world_model: full_qk_from: 4e4 symmetric_qk: true - attention_block_num: 1 + attention_block_num: 3 diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index 77905e5..134e222 100644 --- a/rl_sandbox/config/config_attention.yaml +++ 
b/rl_sandbox/config/config_attention.yaml @@ -21,7 +21,7 @@ agent: kl_free_nats: 1.0 logger: - message: Attention, without dino, Fixed attn, pos enc, determ_updated + message: Attention, without dino, 3 attn layer num, remove clone, fix updated log_grads: false training: diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 682e53e..1a3bd2b 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -20,7 +20,7 @@ def on_step(self, logger): pass def on_episode(self, logger): - pass + self.episode += 1 def on_val(self, logger, rollouts: list[Rollout], global_step: int): metrics = self.calculate_metrics(rollouts) @@ -28,7 +28,6 @@ def on_val(self, logger, rollouts: list[Rollout], global_step: int): if self.log_video: video = rollouts[0].obs.unsqueeze(0) logger.add_video('val/visualization', self.agent.unprocess_obs(video), global_step) - self.episode += 1 def calculate_metrics(self, rollouts: list[Rollout]): return { @@ -68,6 +67,7 @@ def on_step(self, logger): def on_episode(self, logger): latent_hist = (self._latent_probs / self.stored_steps).detach().cpu().numpy() self.latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) + self.action_hist = (self.agent._action_probs / self.stored_steps).detach().cpu().numpy() self.reset_ep() self.episode += 1 @@ -80,16 +80,15 @@ def on_val(self, logger, rollouts: list[Rollout], global_step: int): # if discrete action space if self.agent.is_discrete: - action_hist = (self.agent._action_probs / self.stored_steps).detach().cpu().numpy() fig = plt.Figure() ax = fig.add_axes([0, 0, 1, 1]) - ax.bar(np.arange(self.agent.actions_num), action_hist) + ax.bar(np.arange(self.agent.actions_num), self.action_hist) logger.add_figure('val/action_probs', fig, self.episode) else: # log mean +- std pass - logger.add_image('val/latent_probs', self.latent_hist, global_step, dataformats='HW') - logger.add_image('val/latent_probs_sorted', np.sort(self.latent_hist, axis=1), global_step, dataformats='HW') + logger.add_image('val/latent_probs', np.expand_dims(self.latent_hist, 0), global_step, dataformats='HW') + logger.add_image('val/latent_probs_sorted', np.expand_dims(np.sort(self.latent_hist, axis=1), 0), global_step, dataformats='HW') def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): # obs = self.agent.preprocess_obs(obs) @@ -174,13 +173,13 @@ def on_val(self, logger, rollouts: list[Rollout], global_step: int): wm = self.agent.world_model if hasattr(wm.recurrent_model, 'last_attention'): - logger.add_image('val/mixer_attention', wm.recurrent_model.last_attention, self.episode, dataformats='HW') + logger.add_image('val/mixer_attention', wm.recurrent_model.last_attention.unsqueeze(0), global_step, dataformats='HW') - logger.add_image('val/slot_attention_mu', self.mu_hist/self.mu_hist.max(), self.episode, dataformats='HW') - logger.add_image('val/slot_attention_sigma', self.sigma_hist/self.sigma_hist.max(), self.episode, dataformats='HW') + logger.add_image('val/slot_attention_mu', (self.mu_hist/self.mu_hist.max()).unsqueeze(0), global_step, dataformats='HW') + logger.add_image('val/slot_attention_sigma', (self.sigma_hist/self.sigma_hist.max()).unsqueeze(0), global_step, dataformats='HW') - logger.add_scalar('val/slot_attention_mu_diff_max', self.mu_hist.max(), self.episode) - logger.add_scalar('val/slot_attention_sigma_diff_max', self.sigma_hist.max(), self.episode) + logger.add_scalar('val/slot_attention_mu_diff_max', self.mu_hist.max(), global_step) + 
logger.add_scalar('val/slot_attention_sigma_diff_max', self.sigma_hist.max(), global_step) def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): # obs = torch.from_numpy(obs.copy()).to(self.agent.device) From 0e610b2eaf7f294632dacf691464a39a5f62f01a Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Tue, 18 Jul 2023 13:14:26 +0100 Subject: [PATCH 089/106] Added attention block before image update --- .../agents/dreamer/rssm_slots_attention.py | 20 +++++++++++++++++++ rl_sandbox/config/config_attention.yaml | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/rl_sandbox/agents/dreamer/rssm_slots_attention.py b/rl_sandbox/agents/dreamer/rssm_slots_attention.py index c412571..65874d4 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_attention.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_attention.py @@ -153,6 +153,13 @@ def __init__(self, self.att_scale = hidden_size**(-0.5) self.eps = 1e-8 + self.hidden_attention_proj_obs = nn.Linear(embed_size, 2*embed_size) + self.hidden_attention_proj_obs_state = nn.Linear(hidden_size, embed_size) + self.pre_norm_obs = nn.LayerNorm(embed_size) + + self.fc_obs = nn.Linear(embed_size, embed_size) + self.fc_norm_obs = nn.LayerNorm(embed_size) + def on_train_step(self): self.attention_scheduler.step() @@ -209,6 +216,19 @@ def predict_next(self, prev_state: State, action) -> State: predicted_stoch_logits, pos_enc=prev_state.pos_enc, determ_updated=determ_post), diff def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' + for _ in range(self.attention_block_num): + k = self.hidden_attention_proj_obs_state(self.pre_norm(prior.determ_updated)) + q, v = self.hidden_attention_proj_obs(self.pre_norm_obs(embed)).chunk(2, dim=-1) + qk = torch.einsum('lbih,lbjh->lbij', q, k) + + # TODO: Use Gumbel Softmax + attn = torch.softmax(self.att_scale * qk + self.eps, dim=-1) + attn = attn / attn.sum(dim=-1, keepdim=True) + + updates = torch.einsum('lbij,lbjh->lbih', attn, v) + # TODO: Try just using updates instead of embed + embed = embed + self.fc_obs(self.fc_norm_obs(updates)) + return State( prior.determ, self.stoch_net(torch.concat([prior.determ_updated, embed], dim=-1)).flatten( diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index 134e222..91214f4 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -21,7 +21,7 @@ agent: kl_free_nats: 1.0 logger: - message: Attention, without dino, 3 attn layer num, remove clone, fix updated + message: Attention, without dino, add attention for embed choosing log_grads: false training: From e80af480ce40128fd814994efdc0e970de76e723 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Tue, 18 Jul 2023 21:33:50 +0100 Subject: [PATCH 090/106] Added identity decay for embed attention --- .../agents/dreamer/rssm_slots_attention.py | 24 ++++++++++--------- rl_sandbox/config/config_attention.yaml | 2 +- rl_sandbox/metrics.py | 1 + 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/rl_sandbox/agents/dreamer/rssm_slots_attention.py b/rl_sandbox/agents/dreamer/rssm_slots_attention.py index 65874d4..4890a9a 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_attention.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_attention.py @@ -153,7 +153,7 @@ def __init__(self, self.att_scale = hidden_size**(-0.5) self.eps = 1e-8 - self.hidden_attention_proj_obs = nn.Linear(embed_size, 2*embed_size) + self.hidden_attention_proj_obs = nn.Linear(embed_size, embed_size) 
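# --- editorial aside (not part of the patch series) ---------------------------
# PATCH 089/090 add a cross-attention between the prior deterministic slot
# states (keys) and the per-slot observation embeddings (queries), and blend
# the resulting mixing matrix with the identity so that, early in training,
# each embedding still updates its own slot. A stand-alone sketch of that
# re-mixing step; `EmbedToSlotMixer` and the explicit `coeff` argument are
# editorial names (the patches derive the coefficient from
# `attention_scheduler.val`).
import torch
from torch import nn

class EmbedToSlotMixer(nn.Module):
    def __init__(self, hidden_size: int, embed_size: int):
        super().__init__()
        self.state_norm = nn.LayerNorm(hidden_size)
        self.embed_norm = nn.LayerNorm(embed_size)
        self.k_proj = nn.Linear(hidden_size, embed_size)  # prior determ -> keys
        self.q_proj = nn.Linear(embed_size, embed_size)   # embeddings  -> queries
        self.scale = embed_size ** -0.5

    def forward(self, determ: torch.Tensor, embed: torch.Tensor, coeff: float) -> torch.Tensor:
        # determ: (seq, batch, slots, hidden); embed: (seq, batch, slots, embed)
        k = self.k_proj(self.state_norm(determ))
        q = self.q_proj(self.embed_norm(embed))
        attn = torch.softmax(self.scale * torch.einsum('lbih,lbjh->lbij', q, k), dim=-1)
        eye = torch.eye(embed.shape[-2], device=embed.device)
        attn = coeff * attn + (1 - coeff) * eye
        # Each output slot receives a convex combination of the observation
        # embeddings, weighted by its affinity to the prior slot states.
        return torch.einsum('lbij,lbjh->lbih', attn, embed)
# -------------------------------------------------------------------------------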
self.hidden_attention_proj_obs_state = nn.Linear(hidden_size, embed_size) self.pre_norm_obs = nn.LayerNorm(embed_size) @@ -216,18 +216,20 @@ def predict_next(self, prev_state: State, action) -> State: predicted_stoch_logits, pos_enc=prev_state.pos_enc, determ_updated=determ_post), diff def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' - for _ in range(self.attention_block_num): - k = self.hidden_attention_proj_obs_state(self.pre_norm(prior.determ_updated)) - q, v = self.hidden_attention_proj_obs(self.pre_norm_obs(embed)).chunk(2, dim=-1) - qk = torch.einsum('lbih,lbjh->lbij', q, k) + k = self.hidden_attention_proj_obs_state(self.pre_norm(prior.determ_updated)) + q = self.hidden_attention_proj_obs(self.pre_norm_obs(embed)) + qk = torch.einsum('lbih,lbjh->lbij', q, k) - # TODO: Use Gumbel Softmax - attn = torch.softmax(self.att_scale * qk + self.eps, dim=-1) - attn = attn / attn.sum(dim=-1, keepdim=True) + # TODO: Use Gumbel Softmax + attn = torch.softmax(self.att_scale * qk + self.eps, dim=-1) + attn = attn / attn.sum(dim=-1, keepdim=True) - updates = torch.einsum('lbij,lbjh->lbih', attn, v) - # TODO: Try just using updates instead of embed - embed = embed + self.fc_obs(self.fc_norm_obs(updates)) + # TODO: Maybe make this a learnable parameter ? + coeff = min((self.attention_scheduler.val * 5), 1.0) + attn = coeff * attn + (1 - coeff) * torch.eye(q.shape[-2],device=q.device) + + embed = torch.einsum('lbij,lbjh->lbih', attn, embed) + self.embed_attn = attn.squeeze() return State( prior.determ, diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index 91214f4..8232147 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -21,7 +21,7 @@ agent: kl_free_nats: 1.0 logger: - message: Attention, without dino, add attention for embed choosing + message: Attention, without dino, Attn X Embed with Identity attn decay log_grads: false training: diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 1a3bd2b..24f656f 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -174,6 +174,7 @@ def on_val(self, logger, rollouts: list[Rollout], global_step: int): if hasattr(wm.recurrent_model, 'last_attention'): logger.add_image('val/mixer_attention', wm.recurrent_model.last_attention.unsqueeze(0), global_step, dataformats='HW') + logger.add_image('val/embed_attention', wm.recurrent_model.embed_attn.unsqueeze(0), global_step, dataformats='HW') logger.add_image('val/slot_attention_mu', (self.mu_hist/self.mu_hist.max()).unsqueeze(0), global_step, dataformats='HW') logger.add_image('val/slot_attention_sigma', (self.sigma_hist/self.sigma_hist.max()).unsqueeze(0), global_step, dataformats='HW') From 9bf4054897a897c97a84e2486cda2dc4ae8e7f67 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Tue, 18 Jul 2023 23:05:24 +0100 Subject: [PATCH 091/106] Added vit error visualization --- .../agents/dreamer/rssm_slots_attention.py | 32 +++++++++---------- .../dreamer/world_model_slots_attention.py | 4 +-- rl_sandbox/config/config_attention.yaml | 14 ++++---- rl_sandbox/metrics.py | 32 +++++++++++++------ 4 files changed, 48 insertions(+), 34 deletions(-) diff --git a/rl_sandbox/agents/dreamer/rssm_slots_attention.py b/rl_sandbox/agents/dreamer/rssm_slots_attention.py index 4890a9a..582b8c2 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_attention.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_attention.py @@ -153,12 +153,12 @@ def __init__(self, self.att_scale = hidden_size**(-0.5) 
self.eps = 1e-8 - self.hidden_attention_proj_obs = nn.Linear(embed_size, embed_size) - self.hidden_attention_proj_obs_state = nn.Linear(hidden_size, embed_size) - self.pre_norm_obs = nn.LayerNorm(embed_size) + # self.hidden_attention_proj_obs = nn.Linear(embed_size, embed_size) + # self.hidden_attention_proj_obs_state = nn.Linear(hidden_size, embed_size) + # self.pre_norm_obs = nn.LayerNorm(embed_size) - self.fc_obs = nn.Linear(embed_size, embed_size) - self.fc_norm_obs = nn.LayerNorm(embed_size) + # self.fc_obs = nn.Linear(embed_size, embed_size) + # self.fc_norm_obs = nn.LayerNorm(embed_size) def on_train_step(self): self.attention_scheduler.step() @@ -216,20 +216,20 @@ def predict_next(self, prev_state: State, action) -> State: predicted_stoch_logits, pos_enc=prev_state.pos_enc, determ_updated=determ_post), diff def update_current(self, prior: State, embed) -> State: # Dreamer 'obs_out' - k = self.hidden_attention_proj_obs_state(self.pre_norm(prior.determ_updated)) - q = self.hidden_attention_proj_obs(self.pre_norm_obs(embed)) - qk = torch.einsum('lbih,lbjh->lbij', q, k) + # k = self.hidden_attention_proj_obs_state(self.pre_norm(prior.determ_updated)) + # q = self.hidden_attention_proj_obs(self.pre_norm_obs(embed)) + # qk = torch.einsum('lbih,lbjh->lbij', q, k) - # TODO: Use Gumbel Softmax - attn = torch.softmax(self.att_scale * qk + self.eps, dim=-1) - attn = attn / attn.sum(dim=-1, keepdim=True) + # # TODO: Use Gumbel Softmax + # attn = torch.softmax(self.att_scale * qk + self.eps, dim=-1) + # attn = attn / attn.sum(dim=-1, keepdim=True) - # TODO: Maybe make this a learnable parameter ? - coeff = min((self.attention_scheduler.val * 5), 1.0) - attn = coeff * attn + (1 - coeff) * torch.eye(q.shape[-2],device=q.device) + # # TODO: Maybe make this a learnable parameter ? 
+ # coeff = min((self.attention_scheduler.val * 5), 1.0) + # attn = coeff * attn + (1 - coeff) * torch.eye(q.shape[-2],device=q.device) - embed = torch.einsum('lbij,lbjh->lbih', attn, embed) - self.embed_attn = attn.squeeze() + # embed = torch.einsum('lbij,lbjh->lbih', attn, embed) + # self.embed_attn = attn.squeeze() return State( prior.determ, diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index c5f7ebd..219fdb5 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -50,7 +50,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.vit_img_size = vit_img_size self.per_slot_rec_loss = per_slot_rec_loss - self.n_dim = 192 + self.n_dim = 384 self.recurrent_model = RSSM( latent_dim, @@ -98,7 +98,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4, 4], - channel_step=48, + channel_step=48 * (self.n_dim // 192), double_conv=True, flatten_output=False) diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index 8232147..2c726b4 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -7,21 +7,21 @@ defaults: - override hydra/launcher: joblib seed: 42 -device_type: cuda:1 +device_type: cuda:0 agent: world_model: encode_vit: false - decode_vit: false - #vit_img_size: 224 - #vit_l2_ratio: 0.5 + decode_vit: true + vit_img_size: 224 + vit_l2_ratio: 1.0 slots_iter_num: 3 slots_num: 6 kl_loss_scale: 1.0 kl_free_nats: 1.0 logger: - message: Attention, without dino, Attn X Embed with Identity attn decay + message: Attention, only dino, determ_update, 384 slot n_dim log_grads: false training: @@ -36,8 +36,8 @@ validation: - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator log_video: True _partial_: true - #- _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator - - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + - _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator + #- _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator _partial_: true debug: diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 24f656f..ad2b769 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -259,7 +259,7 @@ def viz_log(self, rollout, logger, epoch_num): logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) class SlottedDinoDreamerMetricsEvaluator(SlottedDreamerMetricsEvaluator): - def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): + def _generate_video(self, obs: list[Observation], actions: list[Action], d_feats: list[torch.Tensor], update_num: int): # obs = torch.from_numpy(obs.copy()).to(self.agent.device) # obs = self.agent.preprocess_obs(obs) # actions = self.agent.from_np(actions) @@ -268,13 +268,15 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ video = [] slots_video = [] vit_slots_video = [] + vit_mean_err_video = [] + vit_max_err_video = [] rews = [] vit_size = self.agent.world_model.vit_size state = None prev_slots = None - for idx, (o, a) in enumerate(list(zip(obs, actions))): + for idx, (o, a, d_feat) in enumerate(list(zip(obs, actions, d_feats))): if idx > update_num: break state, prev_slots = self.agent.world_model.get_latent(o, 
a.unsqueeze(0).unsqueeze(0), (state, prev_slots)) @@ -286,8 +288,9 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ decoded_imgs = decoded_imgs * img_mask video_r = torch.sum(decoded_imgs, dim=1) - _, vit_masks = self.agent.world_model.dino_predictor(state.combined_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, vit_size, vit_size).split([self.agent.world_model.vit_feat_dim, 1], dim=2) + decoded_dino_feats, vit_masks = self.agent.world_model.dino_predictor(state.combined_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, vit_size, vit_size).split([self.agent.world_model.vit_feat_dim, 1], dim=2) vit_mask = F.softmax(vit_masks, dim=1) + decoded_dino = (decoded_dino_feats * vit_mask).sum(dim=1) upscale = tv.transforms.Resize(64, antialias=True) upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) @@ -297,6 +300,8 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ video.append(self.agent.unprocess_obs(video_r)) slots_video.append(self.agent.unprocess_obs(decoded_imgs)) vit_slots_video.append(self.agent.unprocess_obs(per_slot_vit/upscaled_mask.max())) + vit_mean_err_video.append(((d_feat.reshape(decoded_dino.shape) - decoded_dino)**2).mean(dim=1)) + vit_max_err_video.append(((d_feat.reshape(decoded_dino.shape) - decoded_dino)**2).max(dim=1).values) rews = torch.Tensor(rews).to(obs.device) @@ -304,24 +309,26 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) rews = torch.cat([rews, rews_2[1:].squeeze()]) - # video_r = self.agent.world_model.image_predictor(states.combined_slots[1:]).mode decoded_imgs, masks = self.agent.world_model.image_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) img_mask = self.agent.world_model.slot_mask(masks) decoded_imgs = decoded_imgs * img_mask video_r = torch.sum(decoded_imgs, dim=1) - _, vit_masks = self.agent.world_model.dino_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, vit_size, vit_size).split([self.agent.world_model.vit_feat_dim, 1], dim=2) + decoded_dino_feats, vit_masks = self.agent.world_model.dino_predictor(states.combined_slots[1:].flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, vit_size, vit_size).split([self.agent.world_model.vit_feat_dim, 1], dim=2) vit_mask = F.softmax(vit_masks, dim=1) + decoded_dino = (decoded_dino_feats * vit_mask).sum(dim=1) + upscale = tv.transforms.Resize(64, antialias=True) upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) per_slot_vit = (upscaled_mask.unsqueeze(2) * obs[update_num+1:].to(self.agent.device).unsqueeze(1)) - # per_slot_vit = (upscaled_mask.unsqueeze(1) * o.to(self.agent.device).unsqueeze(0)).unsqueeze(0) video.append(self.agent.unprocess_obs(video_r)) slots_video.append(self.agent.unprocess_obs(decoded_imgs)) vit_slots_video.append(self.agent.unprocess_obs(per_slot_vit/torch.amax(upscaled_mask, dim=(1,2,3)).view(-1, 1, 1, 1, 1))) + vit_mean_err_video.append(((d_feats[update_num+1:].reshape(decoded_dino.shape) - decoded_dino)**2).mean(dim=1)) + 
vit_max_err_video.append(((d_feats[update_num+1:].reshape(decoded_dino.shape) - decoded_dino)**2).max(dim=1).values) - return torch.cat(video), rews, torch.cat(slots_video), torch.cat(vit_slots_video) + return torch.cat(video), rews, torch.cat(slots_video), torch.cat(vit_slots_video), torch.cat(vit_mean_err_video).unsqueeze(0), torch.cat(vit_max_err_video).unsqueeze(0) def viz_log(self, rollout, logger, epoch_num): rollout = rollout.to(device=self.agent.device) @@ -334,12 +341,17 @@ def viz_log(self, rollout, logger, epoch_num): real_rewards = [rollout.rewards[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces] - videos_r, imagined_rewards, slots_video, vit_masks_video = zip(*[self._generate_video(obs_0, a_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0 in zip( + videos_r, imagined_rewards, slots_video, vit_masks_video, vit_mean_err_video, vit_max_err_video = zip(*[self._generate_video(obs_0, a_0, d_feat_0, update_num=self.agent.imagination_horizon//3) for obs_0, a_0, d_feat_0 in zip( [rollout.obs[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], - [rollout.actions[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces]) + [rollout.actions[idx:idx+ self.agent.imagination_horizon] for idx in init_indeces], + [rollout.additional_data['d_features'][idx:idx+ self.agent.imagination_horizon] for idx in init_indeces]) ]) videos_r = torch.cat(videos_r, dim=3) + vit_mean_err_video = torch.cat(vit_mean_err_video, dim=3) + vit_max_err_video = torch.cat(vit_max_err_video, dim=3) + vit_mean_err_video = (vit_mean_err_video/vit_mean_err_video.max() * 255.0).to(dtype=torch.uint8) + vit_max_err_video = (vit_max_err_video/vit_max_err_video.max() * 255.0).to(dtype=torch.uint8) videos_comparison = torch.cat([videos, videos_r, (torch.abs(videos.float() - videos_r.float() + 1)/2).to(dtype=torch.uint8)], dim=2).unsqueeze(0) @@ -354,6 +366,8 @@ def viz_log(self, rollout, logger, epoch_num): logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) logger.add_video('val/dreamed_slots', slots_video, epoch_num) logger.add_video('val/dreamed_vit_masks', vit_masks_video, epoch_num) + logger.add_video('val/dreamed_vit_masks', vit_mean_err_video.detach().cpu().unsqueeze(2).repeat(1, 1, 3, 1, 1), epoch_num) + logger.add_video('val/dreamed_vit_masks', vit_max_err_video.detach().cpu().unsqueeze(2).repeat(1, 1, 3, 1, 1), epoch_num) # FIXME: rewrite sum(...) as (...).sum() rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() From 8c8e17287f7a023492c63c9c1d350b06b49656cb Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Tue, 18 Jul 2023 23:08:30 +0100 Subject: [PATCH 092/106] fixup! 
Added vit error visualization --- rl_sandbox/metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index ad2b769..63f758e 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -366,8 +366,8 @@ def viz_log(self, rollout, logger, epoch_num): logger.add_video('val/dreamed_rollout', videos_comparison, epoch_num) logger.add_video('val/dreamed_slots', slots_video, epoch_num) logger.add_video('val/dreamed_vit_masks', vit_masks_video, epoch_num) - logger.add_video('val/dreamed_vit_masks', vit_mean_err_video.detach().cpu().unsqueeze(2).repeat(1, 1, 3, 1, 1), epoch_num) - logger.add_video('val/dreamed_vit_masks', vit_max_err_video.detach().cpu().unsqueeze(2).repeat(1, 1, 3, 1, 1), epoch_num) + logger.add_video('val/vit_mean_err', vit_mean_err_video.detach().cpu().unsqueeze(2).repeat(1, 1, 3, 1, 1), epoch_num) + logger.add_video('val/vit_max_err', vit_max_err_video.detach().cpu().unsqueeze(2).repeat(1, 1, 3, 1, 1), epoch_num) # FIXME: rewrite sum(...) as (...).sum() rewards_err = torch.Tensor([torch.abs(sum(imagined_rewards[i]) - real_rewards[i].sum()) for i in range(len(imagined_rewards))]).mean() From bdd8402d273c1979f88955c5f998e3a3b38ac752 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Wed, 19 Jul 2023 14:23:44 +0100 Subject: [PATCH 093/106] Higher kl loss --- rl_sandbox/config/config_attention.yaml | 4 ++-- rl_sandbox/metrics.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index 2c726b4..61287bb 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -17,11 +17,11 @@ agent: vit_l2_ratio: 1.0 slots_iter_num: 3 slots_num: 6 - kl_loss_scale: 1.0 + kl_loss_scale: 5.0 kl_free_nats: 1.0 logger: - message: Attention, only dino, determ_update, 384 slot n_dim + message: Attention, only dino, determ_update, 384 slot n_dim, kl=5 log_grads: false training: diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 63f758e..4ae5485 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -174,7 +174,7 @@ def on_val(self, logger, rollouts: list[Rollout], global_step: int): if hasattr(wm.recurrent_model, 'last_attention'): logger.add_image('val/mixer_attention', wm.recurrent_model.last_attention.unsqueeze(0), global_step, dataformats='HW') - logger.add_image('val/embed_attention', wm.recurrent_model.embed_attn.unsqueeze(0), global_step, dataformats='HW') + # logger.add_image('val/embed_attention', wm.recurrent_model.embed_attn.unsqueeze(0), global_step, dataformats='HW') logger.add_image('val/slot_attention_mu', (self.mu_hist/self.mu_hist.max()).unsqueeze(0), global_step, dataformats='HW') logger.add_image('val/slot_attention_sigma', (self.sigma_hist/self.sigma_hist.max()).unsqueeze(0), global_step, dataformats='HW') From 276a872ff30f837d1fac3389355c7e69f2a5b9d6 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Thu, 20 Jul 2023 15:57:32 +0100 Subject: [PATCH 094/106] Increase kl loss and batch size --- .../config/agent/dreamer_v2_slotted_attention.yaml | 2 +- rl_sandbox/config/config.yaml | 10 +++++++++- rl_sandbox/config/config_attention.yaml | 4 ++-- rl_sandbox/config/training/crafter.yaml | 2 +- rl_sandbox/metrics.py | 9 ++++++--- 5 files changed, 19 insertions(+), 8 deletions(-) diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml index 48cf106..b3b5780 100644 --- 
a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -16,5 +16,5 @@ world_model: vit_l2_ratio: 0.1 full_qk_from: 4e4 - symmetric_qk: true + symmetric_qk: false attention_block_num: 3 diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 626d46c..a973249 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -9,8 +9,16 @@ defaults: seed: 42 device_type: cuda +agent: + world_model: + decode_vit: true + vit_img_size: 224 + vit_l2_ratio: 1.0 + kl_loss_scale: 5.0 + kl_free_nats: 1.0 + logger: - message: Crafter default + message: Dreamer with dino, kl=5 log_grads: false training: diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index 61287bb..67549f6 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -17,11 +17,11 @@ agent: vit_l2_ratio: 1.0 slots_iter_num: 3 slots_num: 6 - kl_loss_scale: 5.0 + kl_loss_scale: 20.0 kl_free_nats: 1.0 logger: - message: Attention, only dino, determ_update, 384 slot n_dim, kl=5 + message: Attention, only dino, kl=20, removed symmetric log_grads: false training: diff --git a/rl_sandbox/config/training/crafter.yaml b/rl_sandbox/config/training/crafter.yaml index ba1943f..3ff03fa 100644 --- a/rl_sandbox/config/training/crafter.yaml +++ b/rl_sandbox/config/training/crafter.yaml @@ -1,6 +1,6 @@ steps: 1e6 prefill: 10000 -batch_size: 16 +batch_size: 50 f16_precision: false pretrain: 1 prioritize_ends: true diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 4ae5485..ef95bec 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -19,9 +19,12 @@ def __init__(self, agent: 'DreamerV2', log_video: bool = False): def on_step(self, logger): pass - def on_episode(self, logger): + def on_episode(self, logger, rollout, global_step: int): self.episode += 1 + metrics = self.calculate_metrics([rollout]) + logger.log(metrics, global_step, mode='train') + def on_val(self, logger, rollouts: list[Rollout], global_step: int): metrics = self.calculate_metrics(rollouts) logger.log(metrics, global_step, mode='val') @@ -64,7 +67,7 @@ def on_step(self, logger): self._action_probs += self._action_probs self._latent_probs += self.agent._state.stoch_dist.base_dist.probs.squeeze().mean(dim=0) - def on_episode(self, logger): + def on_episode(self, logger, rollout, global_step: int): latent_hist = (self._latent_probs / self.stored_steps).detach().cpu().numpy() self.latent_hist = ((latent_hist / latent_hist.max() * 255.0 )).astype(np.uint8) self.action_hist = (self.agent._action_probs / self.stored_steps).detach().cpu().numpy() @@ -153,7 +156,7 @@ def on_step(self, logger): self._action_probs += self._action_probs self._latent_probs += self.agent._state[0].stoch_dist.base_dist.probs.squeeze().mean(dim=0) - def on_episode(self, logger): + def on_episode(self, logger, rollout): wm = self.agent.world_model mu = wm.slot_attention.slots_mu From 706e9731e46d337da7b2f2708717ce997d889367 Mon Sep 17 00:00:00 2001 From: Midren Date: Fri, 21 Jul 2023 17:06:23 +0100 Subject: [PATCH 095/106] Added score calculation for crafter --- rl_sandbox/config/config.yaml | 2 + rl_sandbox/crafter_metrics.py | 78 +++++++++++++++++++++++++++++++++++ rl_sandbox/metrics.py | 4 +- rl_sandbox/train.py | 2 +- 4 files changed, 83 insertions(+), 3 deletions(-) create mode 100644 rl_sandbox/crafter_metrics.py diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 
a973249..b681dda 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -35,6 +35,8 @@ validation: _partial_: true - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator _partial_: true + - _target_: rl_sandbox.crafter_metrics.CrafterMetricsEvaluator + _partial_: true debug: profiler: false diff --git a/rl_sandbox/crafter_metrics.py b/rl_sandbox/crafter_metrics.py new file mode 100644 index 0000000..a1ba450 --- /dev/null +++ b/rl_sandbox/crafter_metrics.py @@ -0,0 +1,78 @@ +import json +import pathlib +import warnings +import collections +from pathlib import Path + +import numpy as np + +from rl_sandbox.utils.replay_buffer import Rollout + +def compute_scores(percents): + # Geometric mean with an offset of 1%. + assert (0 <= percents).all() and (percents <= 100).all() + if (percents <= 1.0).all(): + print('Warning: The input may not be in the right range.') + with warnings.catch_warnings(): # Empty seeds become NaN. + warnings.simplefilter('ignore', category=RuntimeWarning) + scores = np.exp(np.nanmean(np.log(1 + percents), -1)) - 1 + return scores + + +def load_stats(filename, budget): + steps = 0 + rewards = [] + lengths = [] + achievements = collections.defaultdict(list) + for line in filename.read_text().split('\n'): + if not line.strip(): + continue + episode = json.loads(line) + steps += episode['length'] + if steps > budget: + break + lengths.append(episode['length']) + for key, value in episode.items(): + if key.startswith('achievement_'): + achievements[key].append(value) + unlocks = int(np.sum([(v[-1] >= 1) for v in achievements.values()])) + health = -0.9 + rewards.append(unlocks + health) + return rewards, lengths, achievements + + +class CrafterMetricsEvaluator(): + def __init__(self, agent: 'DreamerV2'): + self.agent = agent + self.episode = 0 + + def on_val(self, logger, rollouts: list[Rollout], global_step: int): + if logger.log_dir() is None: + return + budget = 1e6 + stats_file = Path(logger.log_dir()) / "stats.jsonl" + _, lengths, achievements = load_stats(stats_file, budget) + + tasks = list(achievements.keys()) + + xs = np.cumsum(lengths).tolist() + episodes = (np.array(xs) <= budget).sum() + percents = np.empty((len(achievements))) + percents[:] = np.nan + for key, values in achievements.items(): + k = tasks.index(key) + percent = 100 * (np.array(values[:episodes]) >= 1).mean() + percents[k] = percent + + score = compute_scores(percents) + + logger.log({"score": score}, global_step, mode='val') + + def on_step(self, logger): + pass + + def on_episode(self, logger, rollout, global_step: int): + pass + + + diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index ef95bec..dc48443 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -156,7 +156,7 @@ def on_step(self, logger): self._action_probs += self._action_probs self._latent_probs += self.agent._state[0].stoch_dist.base_dist.probs.squeeze().mean(dim=0) - def on_episode(self, logger, rollout): + def on_episode(self, logger, rollout, global_step: int): wm = self.agent.world_model mu = wm.slot_attention.slots_mu @@ -165,7 +165,7 @@ def on_episode(self, logger, rollout): self.sigma_hist = torch.mean((sigma - sigma.squeeze(0).unsqueeze(1)) ** 2, dim=-1) - super().on_episode(logger) + super().on_episode(logger, rollout, global_step) def on_val(self, logger, rollouts: list[Rollout], global_step: int): super().on_val(logger, rollouts, global_step) diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 596b4f1..5c61b59 100644 --- a/rl_sandbox/train.py +++ 
b/rl_sandbox/train.py @@ -139,7 +139,7 @@ def main(cfg: DictConfig): pbar.update(cfg.env.repeat_action_num) for metric in metrics: - metric.on_episode(logger) + metric.on_episode(logger, buff.rollouts[-1], global_step) # FIXME: find more appealing solution ### Validation From 70dcd58102c27e3892df27c68f51cbb29070280d Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Sun, 23 Jul 2023 23:29:00 +0100 Subject: [PATCH 096/106] fixing differences with baselines --- rl_sandbox/agents/dreamer/ac.py | 3 +-- rl_sandbox/agents/dreamer/vision.py | 16 ++++++------- rl_sandbox/agents/dreamer/world_model.py | 2 +- rl_sandbox/config/agent/dreamer_v2.yaml | 4 ++-- .../config/agent/dreamer_v2_crafter.yaml | 2 +- .../agent/dreamer_v2_slotted_attention.yaml | 12 +++++++--- rl_sandbox/config/config.yaml | 23 ++++++++++--------- rl_sandbox/config/config_attention.yaml | 4 ++-- rl_sandbox/config/training/crafter.yaml | 2 +- rl_sandbox/metrics.py | 2 +- rl_sandbox/utils/replay_buffer.py | 10 ++++---- 11 files changed, 42 insertions(+), 38 deletions(-) diff --git a/rl_sandbox/agents/dreamer/ac.py b/rl_sandbox/agents/dreamer/ac.py index ca8a4b8..bfd0a12 100644 --- a/rl_sandbox/agents/dreamer/ac.py +++ b/rl_sandbox/agents/dreamer/ac.py @@ -115,7 +115,6 @@ def calculate_loss(self, zs: torch.Tensor, vs: torch.Tensor, baseline: torch.Ten losses = {} metrics = {} action_dists = self.actor(zs.detach()) - # baseline = advantage = (vs - baseline).detach() losses['loss_actor_reinforce'] = -(self.rho * action_dists.log_prob( actions.detach()).unsqueeze(2) * discount_factors * advantage).mean() @@ -135,7 +134,7 @@ def calculate_entropy(dist): 'loss_actor_dynamics_backprop'] + losses['loss_actor_entropy'] # mean and std are estimated statistically as tanh transformation is used - sample = action_dists.rsample((128, )) + sample = action_dists.rsample((128,)) act_avg = sample.mean(0) metrics['actor/avg_val'] = act_avg.mean() # metrics['actor/mode_val'] = action_dists.mode.mean() diff --git a/rl_sandbox/agents/dreamer/vision.py b/rl_sandbox/agents/dreamer/vision.py index cd11d73..f0c13ee 100644 --- a/rl_sandbox/agents/dreamer/vision.py +++ b/rl_sandbox/agents/dreamer/vision.py @@ -19,12 +19,14 @@ def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2)) layers.append(norm_layer(1, out_channels)) layers.append(nn.ELU(inplace=True)) - if double_conv: - layers.append( - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')) - layers.append(norm_layer(1, out_channels)) - layers.append(nn.ELU(inplace=True)) in_channels = out_channels + + for i, k in enumerate(kernel_sizes): + layers.append( + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')) + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) + if flatten_output: layers.append(nn.Flatten()) self.net = nn.Sequential(*layers) @@ -46,7 +48,7 @@ def __init__(self, super().__init__() layers = [] self.channel_step = channel_step - self.in_channels = 2 **(len(kernel_sizes)-1) * self.channel_step + self.in_channels = 2 **(len(kernel_sizes)+1) * self.channel_step in_channels = self.in_channels self.convin = nn.Linear(input_size, in_channels) self.return_dist = return_dist @@ -91,8 +93,6 @@ def forward(self, X): class ViTDecoder(nn.Module): - # def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 3, 5, 3]): - # def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, kernel_sizes=[5, 5, 5, 5, 
3]): def __init__(self, input_size, norm_layer: nn.GroupNorm | nn.Identity, diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index ef19d0f..ab4e84c 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -135,7 +135,7 @@ def predict_next(self, prev_state: State, action): reward = self.reward_predictor(prior.combined).mode if self.predict_discount: - discount_factors = self.discount_predictor(prior.combined).sample() + discount_factors = self.discount_predictor(prior.combined).mode else: discount_factors = torch.ones_like(reward) return prior, reward, discount_factors diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index 172aecc..bd74ca0 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -13,9 +13,9 @@ world_model: latent_classes: 32 rssm_dim: 200 discount_loss_scale: 1.0 - kl_loss_scale: 1e1 + kl_loss_scale: 1 kl_loss_balancing: 0.8 - kl_free_nats: 0.00 + kl_free_nats: 1.00 discrete_rssm: false decode_vit: false vit_l2_ratio: 0.5 diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index 8839422..31f9bb6 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -6,7 +6,7 @@ clip_rewards: tanh layer_norm: true world_model: - rssm_dim: 1024 + rssm_dim: 2048 predict_discount: true actor: diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml index b3b5780..9797857 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -4,9 +4,9 @@ defaults: world_model: _target_: rl_sandbox.agents.dreamer.world_model_slots_attention.WorldModel - rssm_dim: 512 - slots_num: 6 - slots_iter_num: 2 + rssm_dim: 1024 + slots_num: 4 + slots_iter_num: 3 kl_loss_scale: 1.0 encode_vit: false decode_vit: true @@ -18,3 +18,9 @@ world_model: full_qk_from: 4e4 symmetric_qk: false attention_block_num: 3 + +wm_optim: + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index b681dda..eaceed6 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -6,19 +6,20 @@ defaults: - _self_ - override hydra/launcher: joblib -seed: 42 -device_type: cuda - -agent: - world_model: - decode_vit: true - vit_img_size: 224 - vit_l2_ratio: 1.0 - kl_loss_scale: 5.0 - kl_free_nats: 1.0 +seed: 43 +device_type: cuda:1 + +#agent: +# world_model: +# decode_vit: true +# vit_img_size: 224 +# vit_l2_ratio: 1.0 +# kl_loss_scale: 5.0 +# kl_loss_balancing: 0.95 +# kl_free_nats: 1.0 logger: - message: Dreamer with dino, kl=5 + message: Dreamer default, 1e6, free_nats, 2048 log_grads: false training: diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index 67549f6..052618e 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -17,11 +17,11 @@ agent: vit_l2_ratio: 1.0 slots_iter_num: 3 slots_num: 6 - kl_loss_scale: 20.0 + kl_loss_scale: 2.0 kl_free_nats: 1.0 logger: - message: Attention, only dino, kl=20, removed symmetric + message: Attention, only dino, kl=20, removed symmetric, add warmup log_grads: false training: diff --git 
a/rl_sandbox/config/training/crafter.yaml b/rl_sandbox/config/training/crafter.yaml index 3ff03fa..ba1943f 100644 --- a/rl_sandbox/config/training/crafter.yaml +++ b/rl_sandbox/config/training/crafter.yaml @@ -1,6 +1,6 @@ steps: 1e6 prefill: 10000 -batch_size: 50 +batch_size: 16 f16_precision: false pretrain: 1 prioritize_ends: true diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index dc48443..0517ee0 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -65,7 +65,7 @@ def on_step(self, logger): if self.agent.is_discrete: self._action_probs += self._action_probs - self._latent_probs += self.agent._state.stoch_dist.base_dist.probs.squeeze().mean(dim=0) + self._latent_probs += self.agent._state.stoch_dist.base_dist.probs.squeeze(0).mean(dim=0) def on_episode(self, logger, rollout, global_step: int): latent_hist = (self._latent_probs / self.stored_steps).detach().cpu().numpy() diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index e452d80..806f5f1 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -125,7 +125,7 @@ def sample( ) -> RolloutChunks: # NOTE: constant creation of numpy arrays from self.rollout_len seems terrible for me s, a, r, t, is_first, additional = [], [], [], [], [], {} - r_indeces = np.random.choice(len(self.rollouts), batch_size, p=np.array(self.rollouts_len) / self.total_num) + r_indeces = np.random.choice(len(self.rollouts), batch_size) s_indeces = [] for r_idx in r_indeces: rollout, r_len = self.rollouts[r_idx], self.rollouts_len[r_idx] @@ -133,14 +133,12 @@ def sample( assert r_len > cluster_size - 1, "Rollout it too small" max_idx = r_len - cluster_size + 1 if self.prioritize_ends: - s_idx = np.random.choice(max_idx - cluster_size + 1, 1).item() + cluster_size - 1 - else: - s_idx = np.random.choice(max_idx, 1).item() + max_idx += cluster_size + s_idx = min(np.random.randint(max_idx), r_len - cluster_size) s_indeces.append(s_idx) is_first.append(torch.zeros(cluster_size)) - if s_idx == 0: - is_first[-1][0] = 1 + is_first[-1][0] = 1 s.append(rollout.obs[s_idx:s_idx + cluster_size]) a.append(rollout.actions[s_idx:s_idx + cluster_size]) From aaa5dc46a537ca588c8fd17ef9c7cf89e9ff21bd Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Tue, 25 Jul 2023 13:16:10 +0000 Subject: [PATCH 097/106] Changes to the decoder --- rl_sandbox/agents/dreamer/vision.py | 17 ++++++------ rl_sandbox/agents/dreamer/world_model.py | 2 +- .../config/agent/dreamer_v2_crafter.yaml | 2 +- rl_sandbox/config/config.yaml | 26 ++++++++++++------- 4 files changed, 27 insertions(+), 20 deletions(-) diff --git a/rl_sandbox/agents/dreamer/vision.py b/rl_sandbox/agents/dreamer/vision.py index f0c13ee..6e3b572 100644 --- a/rl_sandbox/agents/dreamer/vision.py +++ b/rl_sandbox/agents/dreamer/vision.py @@ -71,15 +71,16 @@ def __init__(self, output_padding=0)) layers.append(norm_layer(1, out_channels)) layers.append(nn.ELU(inplace=True)) - for k in conv_kernel_sizes: - layers.append( - nn.Conv2d(out_channels, - out_channels, - kernel_size=k, - padding='same')) - layers.append(norm_layer(1, out_channels)) - layers.append(nn.ELU(inplace=True)) in_channels = out_channels + + for k in conv_kernel_sizes: + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) + layers.append( + nn.Conv2d(output_channels, + output_channels, + kernel_size=k, + padding='same')) self.net = nn.Sequential(*layers) def forward(self, X): diff --git a/rl_sandbox/agents/dreamer/world_model.py 
b/rl_sandbox/agents/dreamer/world_model.py index ab4e84c..095c87c 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -88,7 +88,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(self.state_size, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - conv_kernel_sizes=[3], + conv_kernel_sizes=[3, 3], channel_step=2*self.vit_feat_dim, kernel_sizes=self.decoder_kernels, output_channels=self.vit_feat_dim, diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml index 31f9bb6..8839422 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter.yaml @@ -6,7 +6,7 @@ clip_rewards: tanh layer_norm: true world_model: - rssm_dim: 2048 + rssm_dim: 1024 predict_discount: true actor: diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index eaceed6..9fd23c3 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -7,19 +7,25 @@ defaults: - override hydra/launcher: joblib seed: 43 -device_type: cuda:1 +device_type: cuda:0 -#agent: -# world_model: -# decode_vit: true -# vit_img_size: 224 -# vit_l2_ratio: 1.0 -# kl_loss_scale: 5.0 -# kl_loss_balancing: 0.95 -# kl_free_nats: 1.0 +agent: + world_model: + decode_vit: true + vit_img_size: 224 + vit_l2_ratio: 0.5 + kl_loss_scale: 2.0 + kl_loss_balancing: 0.8 + kl_free_nats: 1.0 + + actor_optim: + lr: 2e-4 + + critic_optim: + lr: 2e-4 logger: - message: Dreamer default, 1e6, free_nats, 2048 + message: Dreamer with 0.5 dino 0.8/2, ac 2xlr, 3-3 decoder at the end log_grads: false training: From 4e437f2825900daea7c4564886fe6bfd9bf1ed5d Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Tue, 25 Jul 2023 14:26:11 +0100 Subject: [PATCH 098/106] Try newer dino-only attention slotted --- rl_sandbox/agents/dreamer/vision.py | 6 +++--- .../agents/dreamer/world_model_slots_attention.py | 12 ++++++------ .../config/agent/dreamer_v2_crafter_slotted.yaml | 4 ++-- .../agent/dreamer_v2_slotted_attention.yaml | 4 ++-- rl_sandbox/config/config_attention.yaml | 15 +++++++++------ 5 files changed, 22 insertions(+), 19 deletions(-) diff --git a/rl_sandbox/agents/dreamer/vision.py b/rl_sandbox/agents/dreamer/vision.py index 6e3b572..8435b95 100644 --- a/rl_sandbox/agents/dreamer/vision.py +++ b/rl_sandbox/agents/dreamer/vision.py @@ -6,8 +6,8 @@ class Encoder(nn.Module): def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, channel_step=96, - kernel_sizes=[4, 4, 4], - double_conv=False, + kernel_sizes=[4, 4, 4, 4], + post_conv_num: int = 0, flatten_output=True, in_channels=3, ): @@ -21,7 +21,7 @@ def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, layers.append(nn.ELU(inplace=True)) in_channels = out_channels - for i, k in enumerate(kernel_sizes): + for k in range(post_conv_num): layers.append( nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')) layers.append(norm_layer(1, out_channels)) diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index 219fdb5..ad3c301 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -97,9 +97,9 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, ) else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - kernel_sizes=[4, 4, 4], - 
channel_step=48 * (self.n_dim // 192), - double_conv=True, + kernel_sizes=[4, 4], + channel_step=48 * (self.n_dim // 192) * 2, + post_conv_num=3, flatten_output=False) self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) @@ -107,7 +107,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if self.encode_vit: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) else: - self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (14, 14)) self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), nn.ReLU(inplace=True), @@ -116,8 +116,8 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - conv_kernel_sizes=[], - channel_step=self.vit_feat_dim, + conv_kernel_sizes=[3, 3], + channel_step=2*self.vit_feat_dim, kernel_sizes=self.decoder_kernels, output_channels=self.vit_feat_dim+1, return_dist=False) diff --git a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml index a34d9cb..d9f7416 100644 --- a/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml @@ -7,8 +7,8 @@ world_model: rssm_dim: 512 slots_num: 6 slots_iter_num: 2 - kl_loss_scale: 1e2 + kl_loss_scale: 1.0 decode_vit: true - use_prev_slots: true + use_prev_slots: false vit_l2_ratio: 0.1 encode_vit: false diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml index 9797857..0a5f68f 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -4,7 +4,7 @@ defaults: world_model: _target_: rl_sandbox.agents.dreamer.world_model_slots_attention.WorldModel - rssm_dim: 1024 + rssm_dim: 768 slots_num: 4 slots_iter_num: 3 kl_loss_scale: 1.0 @@ -13,7 +13,7 @@ world_model: mask_combination: soft use_prev_slots: false per_slot_rec_loss: false - vit_l2_ratio: 0.1 + vit_l2_ratio: 0.5 full_qk_from: 4e4 symmetric_qk: false diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index 052618e..9942360 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -7,21 +7,22 @@ defaults: - override hydra/launcher: joblib seed: 42 -device_type: cuda:0 +device_type: cuda:1 agent: world_model: encode_vit: false - decode_vit: true + decode_vit: false vit_img_size: 224 vit_l2_ratio: 1.0 slots_iter_num: 3 - slots_num: 6 + slots_num: 4 kl_loss_scale: 2.0 + kl_loss_balancing: 0.8 kl_free_nats: 1.0 logger: - message: Attention, only dino, kl=20, removed symmetric, add warmup + message: Attention, without dino, kl=2, removed symmetric, add warmup, 4 slots, 768, log_grads: false training: @@ -36,8 +37,10 @@ validation: - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator log_video: True _partial_: true - - _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator - #- _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + #- _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator + - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + _partial_: true + - _target_: rl_sandbox.crafter_metrics.CrafterMetricsEvaluator _partial_: true debug: From 
582930422a620d4a34cb725a9b097c3fa600399c Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Tue, 1 Aug 2023 20:54:17 +0100 Subject: [PATCH 099/106] Train with f16, improve slot attention --- rl_sandbox/agents/dreamer/ac.py | 2 +- rl_sandbox/agents/dreamer/rssm.py | 13 ++------- .../agents/dreamer/rssm_slots_attention.py | 15 +++------- .../agents/dreamer/rssm_slots_combined.py | 13 ++------- rl_sandbox/agents/dreamer/world_model.py | 8 ++--- .../agents/dreamer/world_model_slots.py | 20 ++++++++----- .../dreamer/world_model_slots_attention.py | 24 +++++++-------- .../dreamer/world_model_slots_combined.py | 29 ++++++++++--------- rl_sandbox/agents/dreamer_v2.py | 13 +++++---- .../agent/dreamer_v2_slotted_combined.yaml | 6 ++++ rl_sandbox/config/config.yaml | 9 +++--- rl_sandbox/config/config_attention.yaml | 21 +++++++++----- rl_sandbox/utils/dists.py | 2 +- rl_sandbox/vision/slot_attention.py | 2 +- 14 files changed, 89 insertions(+), 88 deletions(-) diff --git a/rl_sandbox/agents/dreamer/ac.py b/rl_sandbox/agents/dreamer/ac.py index bfd0a12..e7d8485 100644 --- a/rl_sandbox/agents/dreamer/ac.py +++ b/rl_sandbox/agents/dreamer/ac.py @@ -67,7 +67,7 @@ def lambda_return(self, zs, rs, ds): def calculate_loss(self, zs: torch.Tensor, vs: torch.Tensor, discount_factors: torch.Tensor): - predicted_vs_dist = self.estimate_value(zs.detach()) + predicted_vs_dist = self.estimate_value(zs) losses = { 'loss_critic': -(predicted_vs_dist.log_prob(vs.detach()).unsqueeze(2) * diff --git a/rl_sandbox/agents/dreamer/rssm.py b/rl_sandbox/agents/dreamer/rssm.py index c28e046..ce8746e 100644 --- a/rl_sandbox/agents/dreamer/rssm.py +++ b/rl_sandbox/agents/dreamer/rssm.py @@ -143,15 +143,13 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret # Calculate stochastic state from prior embed # shared between all ensemble models - self.ensemble_prior_estimator = nn.ModuleList([ - nn.Sequential( + self.ensemble_prior_estimator = nn.Sequential( nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' norm_layer(hidden_size), nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' - View((1, -1, latent_dim, self.latent_classes))) for _ in range(self.ensemble_num) - ]) + View((1, -1, latent_dim, self.latent_classes))) # For observation we do not have ensemble # FIXME: very bad magic number @@ -170,12 +168,7 @@ def __init__(self, latent_dim, hidden_size, actions_num, latent_classes, discret self.determ_layer_norm = nn.LayerNorm(hidden_size) def estimate_stochastic_latent(self, prev_determ: torch.Tensor): - dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] - # NOTE: Maybe something smarter can be used instead of - # taking only one random between all ensembles - # NOTE: in Dreamer ensemble_num is always 1 - idx = torch.randint(0, self.ensemble_num, ()) - return dists_per_model[0] + return self.ensemble_prior_estimator(prev_determ) def on_train_step(self): pass diff --git a/rl_sandbox/agents/dreamer/rssm_slots_attention.py b/rl_sandbox/agents/dreamer/rssm_slots_attention.py index 582b8c2..d993028 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_attention.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_attention.py @@ -120,16 +120,13 @@ def __init__(self, # Calculate stochastic state from prior embed # shared between all ensemble models - self.ensemble_prior_estimator = nn.ModuleList([ - nn.Sequential( + self.ensemble_prior_estimator = nn.Sequential( nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' 
norm_layer(hidden_size), nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' View((1, -1, latent_dim, self.latent_classes))) - for _ in range(self.ensemble_num) - ]) # For observation we do not have ensemble img_sz = embed_size @@ -164,12 +161,7 @@ def on_train_step(self): self.attention_scheduler.step() def estimate_stochastic_latent(self, prev_determ: torch.Tensor): - dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] - # NOTE: Maybe something smarter can be used instead of - # taking only one random between all ensembles - # NOTE: in Dreamer ensemble_num is always 1 - idx = torch.randint(0, self.ensemble_num, ()) - return dists_per_model[0] + return self.ensemble_prior_estimator(prev_determ) def predict_next(self, prev_state: State, action) -> State: x = self.pre_determ_recurrent( @@ -193,10 +185,11 @@ def predict_next(self, prev_state: State, action) -> State: # Experiment, when only stochastic part is affected and deterministic is not touched # We keep flow of gradients through determ block, but updating it with stochastic part for _ in range(self.attention_block_num): + # FIXME: Should the the prev stochastic component also be used ? q, k, v = self.hidden_attention_proj(self.pre_norm(determ_post)).chunk(3, dim=-1) if self.symmetric_qk: k = q - qk = torch.einsum('lbih,lbjh->lbij', q, k) + qk = torch.einsum('lbih,lbjh->lbij', q, k).float() attn = torch.softmax(self.att_scale * qk + self.eps, dim=-1) attn = attn / attn.sum(dim=-1, keepdim=True) diff --git a/rl_sandbox/agents/dreamer/rssm_slots_combined.py b/rl_sandbox/agents/dreamer/rssm_slots_combined.py index 39cda6c..19497b6 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_combined.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_combined.py @@ -42,6 +42,7 @@ def combined_slots(self): @property def stoch(self): if self.stoch_ is None: + self.stoch_logits = self.stoch_logits.to(dtype=torch.float32) self.stoch_ = Dist( self.stoch_logits).rsample().reshape(self.stoch_logits.shape[:3] + (-1, )) return self.stoch_ @@ -148,16 +149,13 @@ def __init__(self, # Calculate stochastic state from prior embed # shared between all ensemble models - self.ensemble_prior_estimator = nn.ModuleList([ - nn.Sequential( + self.ensemble_prior_estimator = nn.Sequential( nn.Linear(hidden_size, hidden_size), # Dreamer 'img_out_{k}' norm_layer(hidden_size), nn.ELU(inplace=True), nn.Linear(hidden_size, latent_dim * self.latent_classes), # Dreamer 'img_dist_{k}' View((1, -1, latent_dim, self.latent_classes))) - for _ in range(self.ensemble_num) - ]) img_sz = embed_size self.stoch_net = nn.Sequential( @@ -173,12 +171,7 @@ def on_train_step(self): pass def estimate_stochastic_latent(self, prev_determ: torch.Tensor): - dists_per_model = [model(prev_determ) for model in self.ensemble_prior_estimator] - # NOTE: Maybe something smarter can be used instead of - # taking only one random between all ensembles - # NOTE: in Dreamer ensemble_num is always 1 - idx = torch.randint(0, self.ensemble_num, ()) - return dists_per_model[0] + return self.ensemble_prior_estimator(prev_determ) def predict_next(self, prev_state: State, action) -> State: x = self.pre_determ_recurrent( diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index 095c87c..8c84c08 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -69,7 +69,6 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, 
Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[2], channel_step=384, - double_conv=False, flatten_output=False, in_channels=self.vit_feat_dim ) @@ -81,14 +80,12 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4, 4, 4], - channel_step=48, - double_conv=False) - + channel_step=48) if decode_vit: self.dino_predictor = Decoder(self.state_size, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - conv_kernel_sizes=[3, 3], + conv_kernel_sizes=[3], channel_step=2*self.vit_feat_dim, kernel_sizes=self.decoder_kernels, output_channels=self.vit_feat_dim, @@ -133,6 +130,7 @@ def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): def predict_next(self, prev_state: State, action): prior, _ = self.recurrent_model.predict_next(prev_state, action) + # FIXME: rewrite to utilize batch processing reward = self.reward_predictor(prior.combined).mode if self.predict_discount: discount_factors = self.discount_predictor(prior.combined).mode diff --git a/rl_sandbox/agents/dreamer/world_model_slots.py b/rl_sandbox/agents/dreamer/world_model_slots.py index 99a15c3..f51f6b3 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots.py +++ b/rl_sandbox/agents/dreamer/world_model_slots.py @@ -78,7 +78,6 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[2], channel_step=384, - double_conv=False, flatten_output=False, in_channels=self.vit_feat_dim ) @@ -89,14 +88,14 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, ) else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - kernel_sizes=[4, 4, 4], - channel_step=48, - double_conv=True, + kernel_sizes=[4, 4], + channel_step=48 * (self.n_dim // 192) * 2, + post_conv_num=2, flatten_output=False) self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) if self.encode_vit: - self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (14, 14)) else: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) @@ -155,6 +154,9 @@ def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: (0.229, 0.224, 0.225)), tv.transforms.Resize(self.vit_img_size, antialias=True)]) obs = ToTensor(obs + 0.5) + else: + resize = tv.transforms.Resize(self.vit_img_size, antialias=True) + obs = resize(obs) d_features = self.dino_vit(obs) return {'d_features': d_features} @@ -184,7 +186,7 @@ def predict_next(self, prev_state: State, action): reward = self.reward_predictor(prior.combined).mode if self.predict_discount: - discount_factors = self.discount_predictor(prior.combined).sample() + discount_factors = self.discount_predictor(prior.combined).mode else: discount_factors = torch.ones_like(reward) return prior, reward, discount_factors @@ -197,7 +199,11 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t state, prev_slots = state else: state, prev_slots = state[0], None - embed = self.encoder(obs.unsqueeze(0)) + if self.encode_vit: + resize = tv.transforms.Resize(self.vit_img_size, antialias=True) + embed = self.encoder(resize(obs).unsqueeze(0)) + else: + embed = self.encoder(obs.unsqueeze(0)) embed_with_pos_enc = self.positional_augmenter_inp(embed) pre_slot_features_t = self.slot_mlp( diff 
--git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index ad3c301..b5ccd02 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -83,13 +83,6 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if encode_vit: self.post_vit = nn.Sequential( View((-1, self.vit_feat_dim, self.vit_size, self.vit_size)), - Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - kernel_sizes=[2], - channel_step=384, - double_conv=False, - flatten_output=False, - in_channels=self.vit_feat_dim - ) ) self.encoder = nn.Sequential( self.dino_vit, @@ -99,13 +92,13 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, kernel_sizes=[4, 4], channel_step=48 * (self.n_dim // 192) * 2, - post_conv_num=3, + post_conv_num=2, flatten_output=False) self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) self.register_buffer('pos_enc', torch.from_numpy(get_position_encoding(self.slots_num, self.state_size // slots_num)).to(dtype=torch.float32)) if self.encode_vit: - self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (14, 14)) else: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (14, 14)) @@ -116,7 +109,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - conv_kernel_sizes=[3, 3], + conv_kernel_sizes=[3], channel_step=2*self.vit_feat_dim, kernel_sizes=self.decoder_kernels, output_channels=self.vit_feat_dim+1, @@ -164,6 +157,9 @@ def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: (0.229, 0.224, 0.225)), tv.transforms.Resize(self.vit_img_size, antialias=True)]) obs = ToTensor(obs + 0.5) + else: + resize = tv.transforms.Resize(self.vit_img_size, antialias=True) + obs = resize(obs) d_features = self.dino_vit(obs).squeeze() return {'d_features': d_features} @@ -194,7 +190,7 @@ def predict_next(self, prev_state: State, action): reward = self.reward_predictor(prior.combined).mode if self.predict_discount: - discount_factors = self.discount_predictor(prior.combined).sample() + discount_factors = self.discount_predictor(prior.combined).mode else: discount_factors = torch.ones_like(reward) return prior, reward, discount_factors @@ -207,7 +203,11 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t state, prev_slots = state else: state, prev_slots = state[0], None - embed = self.encoder(obs.unsqueeze(0)) + if self.encode_vit: + resize = tv.transforms.Resize(self.vit_img_size, antialias=True) + embed = self.encoder(resize(obs).unsqueeze(0)) + else: + embed = self.encoder(obs.unsqueeze(0)) embed_with_pos_enc = self.positional_augmenter_inp(embed) pre_slot_features_t = self.slot_mlp( diff --git a/rl_sandbox/agents/dreamer/world_model_slots_combined.py b/rl_sandbox/agents/dreamer/world_model_slots_combined.py index 1a0ebe7..77b8729 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_combined.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_combined.py @@ -78,13 +78,6 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if encode_vit: self.post_vit = nn.Sequential( 
View((-1, self.vit_feat_dim, self.vit_size, self.vit_size)), - Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - kernel_sizes=[2], - channel_step=384, - double_conv=False, - flatten_output=False, - in_channels=self.vit_feat_dim - ) ) self.encoder = nn.Sequential( self.dino_vit, @@ -92,15 +85,15 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, ) else: self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - kernel_sizes=[4, 4, 4], - channel_step=48, - double_conv=True, + kernel_sizes=[4, 4], + channel_step=48 * (self.n_dim // 192) * 2, + post_conv_num=2, flatten_output=False) self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots) self.register_buffer('pos_enc', torch.from_numpy(get_position_encoding(self.slots_num, self.state_size // slots_num)).to(dtype=torch.float32)) if self.encode_vit: - self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4)) + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (14, 14)) else: self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6)) @@ -111,8 +104,8 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, if decode_vit: self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - conv_kernel_sizes=[], - channel_step=self.vit_feat_dim, + conv_kernel_sizes=[3], + channel_step=2*self.vit_feat_dim, kernel_sizes=self.decoder_kernels, output_channels=self.vit_feat_dim+1, return_dist=False) @@ -159,6 +152,9 @@ def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: (0.229, 0.224, 0.225)), tv.transforms.Resize(self.vit_img_size, antialias=True)]) obs = ToTensor(obs + 0.5) + else: + resize = tv.transforms.Resize(self.vit_img_size, antialias=True) + obs = resize(obs) d_features = self.dino_vit(obs).squeeze() return {'d_features': d_features} @@ -189,7 +185,7 @@ def predict_next(self, prev_state: State, action): reward = self.reward_predictor(prior.combined).mode if self.predict_discount: - discount_factors = self.discount_predictor(prior.combined).sample() + discount_factors = self.discount_predictor(prior.combined).mode else: discount_factors = torch.ones_like(reward) return prior, reward, discount_factors @@ -202,6 +198,11 @@ def get_latent(self, obs: torch.Tensor, action, state: t.Optional[tuple[State, t state, prev_slots = state else: state, prev_slots = state[0], None + if self.encode_vit: + resize = tv.transforms.Resize(self.vit_img_size, antialias=True) + embed = self.encoder(resize(obs).unsqueeze(0)) + else: + embed = self.encoder(obs.unsqueeze(0)) embed = self.encoder(obs.unsqueeze(0)) embed_with_pos_enc = self.positional_augmenter_inp(embed) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 7dcc9bd..b34d3fe 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -120,7 +120,6 @@ def preprocess_obs(self, obs: torch.Tensor): return ToTensor(obs.type(torch.float32).permute(order) / 255.0) else: return ((obs.type(torch.float32) / 255.0) - 0.5).permute(order) - # return obs.type(torch.float32).permute(order) def unprocess_obs(self, obs: torch.Tensor): order = list(range(len(obs.shape))) @@ -181,17 +180,21 @@ def train(self, rollout_chunks: RolloutChunks): initial_states = discovered_states.flatten().detach() states, actions, rewards, discount_factors = self.imagine_trajectory(initial_states) + + rewards = rewards.float() + discount_factors = 
discount_factors.float() + zs = states.combined rewards = self.world_model.reward_normalizer(rewards) + vs = self.critic.lambda_return(zs, rewards[:-1], discount_factors) + # Discounted factors should be shifted as they predict whether next state cannot be used # First discount factor on contrary is always 1 as it cannot lead to trajectory finish - discount_factors = torch.cat([torch.ones_like(discount_factors[:1]), discount_factors[:-1]], dim=0).detach() - - vs = self.critic.lambda_return(zs, rewards[:-1], discount_factors) + discount_factors = torch.cat([torch.ones_like(discount_factors[:1]), discount_factors[:-1]], dim=0) # Ignore all factors after first is_finished state - discount_factors = torch.cumprod(discount_factors, dim=0) + discount_factors = torch.cumprod(discount_factors, dim=0).detach() losses_c, metrics_c = self.critic.calculate_loss(zs[:-1], vs, discount_factors[:-1]) diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml index 546b68e..82eff7a 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_combined.yaml @@ -14,3 +14,9 @@ world_model: use_prev_slots: false per_slot_rec_loss: false vit_l2_ratio: 0.1 + +wm_optim: + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 9fd23c3..0bafcc2 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -7,14 +7,14 @@ defaults: - override hydra/launcher: joblib seed: 43 -device_type: cuda:0 +device_type: cuda:1 agent: world_model: decode_vit: true vit_img_size: 224 - vit_l2_ratio: 0.5 - kl_loss_scale: 2.0 + vit_l2_ratio: 1.0 + kl_loss_scale: 3.0 kl_loss_balancing: 0.8 kl_free_nats: 1.0 @@ -25,10 +25,11 @@ agent: lr: 2e-4 logger: - message: Dreamer with 0.5 dino 0.8/2, ac 2xlr, 3-3 decoder at the end + message: Dreamer with only dino 0.8/3, fp16 log_grads: false training: + f16_precision: true checkpoint_path: null steps: 1e6 val_logs_every: 2e4 diff --git a/rl_sandbox/config/config_attention.yaml b/rl_sandbox/config/config_attention.yaml index 9942360..96b01e5 100644 --- a/rl_sandbox/config/config_attention.yaml +++ b/rl_sandbox/config/config_attention.yaml @@ -7,25 +7,32 @@ defaults: - override hydra/launcher: joblib seed: 42 -device_type: cuda:1 +device_type: cuda:0 agent: world_model: encode_vit: false - decode_vit: false + decode_vit: true vit_img_size: 224 vit_l2_ratio: 1.0 slots_iter_num: 3 slots_num: 4 - kl_loss_scale: 2.0 - kl_loss_balancing: 0.8 + kl_loss_scale: 3.0 + kl_loss_balancing: 0.6 kl_free_nats: 1.0 + actor_optim: + lr: 1e-4 + + critic_optim: + lr: 1e-4 + logger: - message: Attention, without dino, kl=2, removed symmetric, add warmup, 4 slots, 768, + message: Attention, only dino, kl=0.6/3, 14x14, 768 rssm, no fp16, reverse dino log_grads: false training: + f16_precision: false checkpoint_path: null steps: 1e6 val_logs_every: 2e4 @@ -37,8 +44,8 @@ validation: - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator log_video: True _partial_: true - #- _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator - - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator + - _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator + #- _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator _partial_: true - _target_: rl_sandbox.crafter_metrics.CrafterMetricsEvaluator _partial_: true diff --git 
a/rl_sandbox/utils/dists.py b/rl_sandbox/utils/dists.py index 532bbf1..e119c5d 100644 --- a/rl_sandbox/utils/dists.py +++ b/rl_sandbox/utils/dists.py @@ -202,7 +202,7 @@ def get_trunc_normal(x, min_std=0.1): return TruncatedNormal(loc=torch.tanh(mean).float(), scale=(2*torch.sigmoid(std/2) + min_std).float(), a=-1, b=1) self.dist = get_trunc_normal case 'binary': - self.dist = lambda x: td.Bernoulli(logits=x) + self.dist = lambda x: td.Bernoulli(logits=x.float()) case _: raise RuntimeError("Invalid dist layer") diff --git a/rl_sandbox/vision/slot_attention.py b/rl_sandbox/vision/slot_attention.py index a23fec6..05816d3 100644 --- a/rl_sandbox/vision/slot_attention.py +++ b/rl_sandbox/vision/slot_attention.py @@ -66,7 +66,7 @@ def forward(self, X: Float[torch.Tensor, 'batch seq n_dim'], prev_slots: t.Optio slots = self.slots_norm(slots) q = self.slots_proj(slots) - attn = F.softmax(self.scale*torch.einsum('bik,bjk->bij', q, k), dim=1) + self.epsilon + attn = F.softmax(self.scale*torch.einsum('bik,bjk->bij', q, k).float(), dim=1) + self.epsilon attn = attn / attn.sum(dim=-1, keepdim=True) self.last_attention = attn From 384b253b24ea028be12d8228866c4f395914e98c Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Fri, 11 Aug 2023 18:19:10 +0100 Subject: [PATCH 100/106] Pytorch 2.0 support, Spatial Broadcast Decoder and fixes --- rl_sandbox/agents/dreamer/ac.py | 2 +- .../agents/dreamer/rssm_slots_attention.py | 6 +- rl_sandbox/agents/dreamer/vision.py | 55 +++++++++++++++++- rl_sandbox/agents/dreamer/world_model.py | 11 ++-- rl_sandbox/agents/dreamer_v2.py | 6 +- rl_sandbox/utils/dists.py | 56 ++++++++----------- rl_sandbox/utils/replay_buffer.py | 2 +- rl_sandbox/utils/rollout_generation.py | 2 +- rl_sandbox/vision/dino.py | 12 ++-- rl_sandbox/vision/slot_attention.py | 18 +++--- 10 files changed, 107 insertions(+), 63 deletions(-) diff --git a/rl_sandbox/agents/dreamer/ac.py b/rl_sandbox/agents/dreamer/ac.py index e7d8485..bfd0a12 100644 --- a/rl_sandbox/agents/dreamer/ac.py +++ b/rl_sandbox/agents/dreamer/ac.py @@ -67,7 +67,7 @@ def lambda_return(self, zs, rs, ds): def calculate_loss(self, zs: torch.Tensor, vs: torch.Tensor, discount_factors: torch.Tensor): - predicted_vs_dist = self.estimate_value(zs) + predicted_vs_dist = self.estimate_value(zs.detach()) losses = { 'loss_critic': -(predicted_vs_dist.log_prob(vs.detach()).unsqueeze(2) * diff --git a/rl_sandbox/agents/dreamer/rssm_slots_attention.py b/rl_sandbox/agents/dreamer/rssm_slots_attention.py index d993028..6d10c06 100644 --- a/rl_sandbox/agents/dreamer/rssm_slots_attention.py +++ b/rl_sandbox/agents/dreamer/rssm_slots_attention.py @@ -139,7 +139,7 @@ def __init__(self, latent_dim * self.latent_classes), # Dreamer 'obs_dist' View((1, -1, latent_dim, self.latent_classes))) - self.hidden_attention_proj = nn.Linear(hidden_size, 3*hidden_size) + self.hidden_attention_proj = nn.Linear(hidden_size, 3*hidden_size, bias=False) self.pre_norm = nn.LayerNorm(hidden_size) self.fc = nn.Linear(hidden_size, hidden_size) @@ -191,13 +191,13 @@ def predict_next(self, prev_state: State, action) -> State: k = q qk = torch.einsum('lbih,lbjh->lbij', q, k).float() - attn = torch.softmax(self.att_scale * qk + self.eps, dim=-1) + attn = torch.softmax(self.att_scale * qk, dim=-1) + self.eps attn = attn / attn.sum(dim=-1, keepdim=True) coeff = self.attention_scheduler.val attn = coeff * attn + (1 - coeff) * torch.eye(q.shape[-2],device=q.device) - updates = torch.einsum('lbij,lbjh->lbih', attn, v) + updates = torch.einsum('lbjd,lbij->lbid', v, attn) 
determ_post = determ_post + self.fc(self.fc_norm(updates)) self.last_attention = attn.mean(dim=1).squeeze() diff --git a/rl_sandbox/agents/dreamer/vision.py b/rl_sandbox/agents/dreamer/vision.py index 8435b95..0b706c5 100644 --- a/rl_sandbox/agents/dreamer/vision.py +++ b/rl_sandbox/agents/dreamer/vision.py @@ -1,5 +1,7 @@ import torch.distributions as td from torch import nn +import torch +from rl_sandbox.vision.slot_attention import PositionalEmbedding class Encoder(nn.Module): @@ -23,7 +25,7 @@ def __init__(self, norm_layer: nn.GroupNorm | nn.Identity, for k in range(post_conv_num): layers.append( - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')) + nn.Conv2d(out_channels, out_channels, kernel_size=5, padding='same')) layers.append(norm_layer(1, out_channels)) layers.append(nn.ELU(inplace=True)) @@ -35,6 +37,57 @@ def forward(self, X): return self.net(X) +class SpatialBroadcastDecoder(nn.Module): + + def __init__(self, + input_size, + norm_layer: nn.GroupNorm | nn.Identity, + kernel_sizes = [3, 3, 3], + out_image=(64, 64), + channel_step=64, + output_channels=3, + return_dist=True): + + super().__init__() + layers = [] + self.channel_step = channel_step + self.in_channels = 2*self.channel_step + self.out_shape = out_image + self.positional_augmenter = PositionalEmbedding(self.in_channels, out_image) + + in_channels = self.in_channels + self.convin = nn.Linear(input_size, in_channels) + self.return_dist = return_dist + + for i, k in enumerate(kernel_sizes): + out_channels = channel_step + if i == len(kernel_sizes) - 1: + out_channels = output_channels + layers.append(nn.Conv2d(in_channels, + out_channels, + kernel_size=k, + padding='same')) + else: + layers.append(nn.Conv2d(in_channels, + out_channels, + kernel_size=k, + padding='same')) + layers.append(norm_layer(1, out_channels)) + layers.append(nn.ELU(inplace=True)) + in_channels = out_channels + + self.net = nn.Sequential(*layers) + + def forward(self, X): + x = self.convin(X) + x = x.view(-1, self.in_channels, 1, 1) + x = torch.tile(x, self.out_shape) + x = self.positional_augmenter(x) + if self.return_dist: + return td.Independent(td.Normal(self.net(x), 1.0), 3) + else: + return self.net(x) + class Decoder(nn.Module): def __init__(self, diff --git a/rl_sandbox/agents/dreamer/world_model.py b/rl_sandbox/agents/dreamer/world_model.py index 8c84c08..a8088ab 100644 --- a/rl_sandbox/agents/dreamer/world_model.py +++ b/rl_sandbox/agents/dreamer/world_model.py @@ -8,7 +8,7 @@ from rl_sandbox.agents.dreamer import Dist, Normalizer, View from rl_sandbox.agents.dreamer.rssm import RSSM, State -from rl_sandbox.agents.dreamer.vision import Decoder, Encoder +from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, SpatialBroadcastDecoder from rl_sandbox.utils.dists import DistLayer from rl_sandbox.utils.fc_nn import fc_nn_generator from rl_sandbox.vision.dino import ViTFeat @@ -83,13 +83,14 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, channel_step=48) if decode_vit: - self.dino_predictor = Decoder(self.state_size, + self.dino_predictor = SpatialBroadcastDecoder(self.state_size, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - conv_kernel_sizes=[3], - channel_step=2*self.vit_feat_dim, - kernel_sizes=self.decoder_kernels, + out_image=(14, 14), + kernel_sizes = [5, 5, 5, 5], + channel_step=self.vit_feat_dim, output_channels=self.vit_feat_dim, return_dist=True) + self.image_predictor = Decoder(self.state_size, norm_layer=nn.GroupNorm if layer_norm else nn.Identity) diff --git 
a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index b34d3fe..3c61120 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -51,11 +51,11 @@ def __init__( raise RuntimeError('Invalid reward clipping') self.is_f16 = f16_precision - self.world_model: WorldModel = world_model(actions_num=actions_num).to(device_type) + self.world_model: WorldModel = torch.compile(world_model(actions_num=actions_num), mode='max-autotune').to(device_type) self.actor: ImaginativeActor = actor(latent_dim=self.world_model.state_size, actions_num=actions_num, is_discrete=self.is_discrete).to(device_type) - self.critic: ImaginativeCritic = critic(latent_dim=self.world_model.state_size).to(device_type) + self.critic: ImaginativeCritic = torch.compile(critic(latent_dim=self.world_model.state_size), mode='max-autotune').to(device_type) self.world_model_optimizer = wm_optim(model=self.world_model, scaler=self.is_f16) if self.world_model.decode_vit and self.world_model.vit_l2_ratio == 1.0: @@ -151,7 +151,7 @@ def get_action(self, obs: Observation) -> Action: if self.is_discrete: return self._last_action.argmax() else: - return self._last_action.squeeze().detach().cpu().numpy() + return self._last_action.squeeze().detach().cpu() def from_np(self, arr: np.ndarray): arr = torch.from_numpy(arr) if isinstance(arr, np.ndarray) else arr diff --git a/rl_sandbox/utils/dists.py b/rl_sandbox/utils/dists.py index e119c5d..759fc3c 100644 --- a/rl_sandbox/utils/dists.py +++ b/rl_sandbox/utils/dists.py @@ -1,5 +1,3 @@ -# Taken from https://raw.githubusercontent.com/toshas/torch_truncnorm/main/TruncatedNormal.py -# Added torch modules on top import math from numbers import Number import typing as t @@ -10,6 +8,7 @@ from torch import nn from torch.distributions import Distribution, constraints from torch.distributions.utils import broadcast_all +from torch.distributions.utils import _standard_normal CONST_SQRT_2 = math.sqrt(2) CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi) @@ -106,39 +105,28 @@ def rsample(self, sample_shape=torch.Size()): p = torch.empty(shape, device=self.a.device).uniform_(self._dtype_min_gt_0, self._dtype_max_lt_1) return self.icdf(p) +class TruncatedNormal(td.Normal): + def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): + super().__init__(loc, scale, validate_args=False) + self.low = low + self.high = high + self.eps = eps -class TruncatedNormal(TruncatedStandardNormal): - """ - Truncated Normal distribution - https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - """ - - has_rsample = True - - def __init__(self, loc, scale, a, b, validate_args=None): - self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b) - a = (a - self.loc) / self.scale - b = (b - self.loc) / self.scale - super(TruncatedNormal, self).__init__(a, b, validate_args=validate_args) - self._log_scale = self.scale.log() - self._mean = self._mean * self.scale + self.loc - self._variance = self._variance * self.scale ** 2 - self._entropy += self._log_scale - - def _to_std_rv(self, value): - return (value - self.loc) / self.scale + def _clamp(self, x): + clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) + x = x - x.detach() + clamped_x.detach() + return x - def _from_std_rv(self, value): - return value * self.scale + self.loc - - def cdf(self, value): - return super(TruncatedNormal, self).cdf(self._to_std_rv(value)) - - def icdf(self, value): - return self._from_std_rv(super(TruncatedNormal, self).icdf(value)) - - def log_prob(self, value): - 
return super(TruncatedNormal, self).log_prob(self._to_std_rv(value)) - self._log_scale + def sample(self, sample_shape=torch.Size(), clip=None): + shape = self._extended_shape(sample_shape) + eps = _standard_normal(shape, + dtype=self.loc.dtype, + device=self.loc.device) + eps *= self.scale + if clip is not None: + eps = torch.clamp(eps, -clip, clip) + x = self.loc + eps + return self._clamp(x) class Sigmoid2(nn.Module): @@ -199,7 +187,7 @@ def get_tanh_normal(x, min_std=0.1): case 'normal_trunc': def get_trunc_normal(x, min_std=0.1): mean, std = x.chunk(2, dim=-1) - return TruncatedNormal(loc=torch.tanh(mean).float(), scale=(2*torch.sigmoid(std/2) + min_std).float(), a=-1, b=1) + return TruncatedNormal(loc=torch.tanh(mean).float(), scale=(2*torch.sigmoid(std/2) + min_std).float()) self.dist = get_trunc_normal case 'binary': self.dist = lambda x: td.Bernoulli(logits=x.float()) diff --git a/rl_sandbox/utils/replay_buffer.py b/rl_sandbox/utils/replay_buffer.py index 806f5f1..e916ef9 100644 --- a/rl_sandbox/utils/replay_buffer.py +++ b/rl_sandbox/utils/replay_buffer.py @@ -107,7 +107,7 @@ def add_sample(self, env_step: EnvStep): self.add_rollout( Rollout( torch.stack(self.curr_rollout.obs), - torch.stack(self.curr_rollout.actions).reshape(-1, 1), + torch.stack(self.curr_rollout.actions).reshape(len(self.curr_rollout.actions), -1), torch.Tensor(self.curr_rollout.rewards), torch.Tensor(self.curr_rollout.is_finished), torch.Tensor(self.curr_rollout.is_first), diff --git a/rl_sandbox/utils/rollout_generation.py b/rl_sandbox/utils/rollout_generation.py index 7ae5423..c08012c 100644 --- a/rl_sandbox/utils/rollout_generation.py +++ b/rl_sandbox/utils/rollout_generation.py @@ -99,7 +99,7 @@ def collect_rollout(env: Env, for k, v in add.items(): additional[k].append(v) - return Rollout(torch.stack(s), torch.stack(a).reshape(-1, 1), + return Rollout(torch.stack(s), torch.stack(a).reshape(len(a), -1), torch.Tensor(r).float(), torch.Tensor(t), torch.Tensor(f), {k: torch.stack(v) for k, v in additional.items()}) diff --git a/rl_sandbox/vision/dino.py b/rl_sandbox/vision/dino.py index 87f36c1..37a3bc6 100644 --- a/rl_sandbox/vision/dino.py +++ b/rl_sandbox/vision/dino.py @@ -121,7 +121,8 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0. 
def forward(self, x): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + feat_qkv = self.qkv(x) + qkv = feat_qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale @@ -131,7 +132,7 @@ def forward(self, x): x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) - return x, attn + return x, (attn, feat_qkv) class Block(nn.Module): @@ -316,10 +317,6 @@ def __init__(self, pretrained_pth, feat_dim, vit_arch = 'base', vit_feat = 'k', def forward(self, img) : feat_out = {} - def hook_fn_forward_qkv(module, input, output): - feat_out["qkv"] = output - - self.model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) # Forward pass in the model @@ -327,9 +324,10 @@ def hook_fn_forward_qkv(module, input, output): h, w = img.shape[2], img.shape[3] feat_h, feat_w = h // self.patch_size, w // self.patch_size attentions = self.model.get_last_selfattention(img) + attentions, feat_qkv = attentions bs, nb_head, nb_token = attentions.shape[0], attentions.shape[1], attentions.shape[2] qkv = ( - feat_out["qkv"] + feat_qkv .reshape(bs, nb_token, 3, nb_head, -1) .permute(2, 0, 3, 1, 4) ) diff --git a/rl_sandbox/vision/slot_attention.py b/rl_sandbox/vision/slot_attention.py index 05816d3..21975ea 100644 --- a/rl_sandbox/vision/slot_attention.py +++ b/rl_sandbox/vision/slot_attention.py @@ -29,17 +29,17 @@ def __init__(self, num_slots: int, n_dim: int, n_iter: int, use_prev_slots: bool self.slots_logsigma = nn.Parameter(torch.zeros(1, num_slots, self.n_dim)) nn.init.xavier_uniform_(self.slots_logsigma) - self.slots_proj = nn.Linear(n_dim, n_dim) + self.slots_proj = nn.Linear(n_dim, n_dim, bias=False) self.slots_proj_2 = nn.Sequential( - nn.Linear(n_dim, n_dim*2), + nn.Linear(n_dim, n_dim*4), nn.ReLU(inplace=True), - nn.Linear(n_dim*2, n_dim), + nn.Linear(n_dim*4, n_dim), ) self.slots_norm = nn.LayerNorm(self.n_dim) self.slots_norm_2 = nn.LayerNorm(self.n_dim) self.slots_reccur = nn.GRUCell(input_size=self.n_dim, hidden_size=self.n_dim) - self.inputs_proj = nn.Linear(n_dim, n_dim*2) + self.inputs_proj = nn.Linear(n_dim, n_dim*2, bias=False) self.inputs_norm = nn.LayerNorm(self.n_dim) self.prev_slots = None @@ -71,7 +71,7 @@ def forward(self, X: Float[torch.Tensor, 'batch seq n_dim'], prev_slots: t.Optio self.last_attention = attn - updates = torch.einsum('bij,bjk->bik', attn, v) / self.n_slots + updates = torch.einsum('bjd,bij->bid', v, attn) slots = self.slots_reccur(updates.reshape(-1, self.n_dim), slots_prev.reshape(-1, self.n_dim)).reshape(batch, self.n_slots, self.n_dim) slots = slots + self.slots_proj_2(self.slots_norm_2(slots)) return slots @@ -87,14 +87,18 @@ def build_grid(resolution): class PositionalEmbedding(nn.Module): - def __init__(self, n_dim: int, res: t.Tuple[int, int]): + def __init__(self, n_dim: int, res: t.Tuple[int, int], channel_last=False): super().__init__() self.n_dim = n_dim self.proj = nn.Linear(4, n_dim) + self.channel_last = channel_last self.register_buffer('grid', torch.from_numpy(build_grid(res))) def forward(self, X) -> torch.Tensor: - return X + self.proj(self.grid).permute(0, 3, 1, 2) + if self.channel_last: + return X + self.proj(self.grid) + else: + return X + self.proj(self.grid).permute(0, 3, 1, 2) class SlottedAutoEncoder(nn.Module): def __init__(self, num_slots: int, n_iter: int, dino_inp_size: int = 224): From 
86f0ba2c2a0c6a8e0e9eec2af0d45ed4da54fe81 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Fri, 11 Aug 2023 18:22:55 +0100 Subject: [PATCH 101/106] Added slot implementation after the dynamics model --- .../agents/dreamer/world_model_post_slot.py | 316 ++++++++++++++++++ rl_sandbox/config/config_postslot.yaml | 71 ++++ rl_sandbox/config/env/dm_finger_spin.yaml | 8 + rl_sandbox/metrics.py | 64 ++++ 4 files changed, 459 insertions(+) create mode 100644 rl_sandbox/agents/dreamer/world_model_post_slot.py create mode 100644 rl_sandbox/config/config_postslot.yaml create mode 100644 rl_sandbox/config/env/dm_finger_spin.yaml diff --git a/rl_sandbox/agents/dreamer/world_model_post_slot.py b/rl_sandbox/agents/dreamer/world_model_post_slot.py new file mode 100644 index 0000000..4e0c41a --- /dev/null +++ b/rl_sandbox/agents/dreamer/world_model_post_slot.py @@ -0,0 +1,316 @@ +import typing as t + +import torch +import torch.distributions as td +import torchvision as tv +from torch import nn +from torch.nn import functional as F + +from rl_sandbox.agents.dreamer import Dist, Normalizer, View +from rl_sandbox.agents.dreamer.rssm import RSSM, State +from rl_sandbox.agents.dreamer.vision import Decoder, Encoder, SpatialBroadcastDecoder +from rl_sandbox.utils.dists import DistLayer +from rl_sandbox.utils.fc_nn import fc_nn_generator +from rl_sandbox.vision.dino import ViTFeat +from rl_sandbox.vision.slot_attention import PositionalEmbedding, SlotAttention + +class WorldModel(nn.Module): + + def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, + actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, + predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, + vit_l2_ratio: float, vit_img_size: int, slots_num: int, slots_iter_num: int, mask_combination: str): + super().__init__() + self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) + self.discount_scale = discount_loss_scale + self.kl_beta = kl_loss_scale + + self.rssm_dim = rssm_dim + self.latent_dim = latent_dim + self.latent_classes = latent_classes + self.state_size = (rssm_dim + latent_dim * latent_classes) + + self.cluster_size = batch_cluster_size + self.actions_num = actions_num + # kl loss balancing (prior/posterior) + self.alpha = kl_loss_balancing + self.predict_discount = predict_discount + self.encode_vit = encode_vit + self.decode_vit = decode_vit + self.vit_l2_ratio = vit_l2_ratio + self.vit_img_size = vit_img_size + + self.recurrent_model = RSSM(latent_dim, + rssm_dim, + actions_num, + latent_classes, + discrete_rssm, + norm_layer=nn.LayerNorm if layer_norm else nn.Identity) + if encode_vit or decode_vit: + if self.vit_img_size == 224: + self.dino_vit = ViTFeat("/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=16) + self.decoder_kernels = [3, 3, 2] + self.vit_size = 14 + elif self.vit_img_size == 64: + self.dino_vit = ViTFeat("/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", + feat_dim=384, vit_arch='small', patch_size=8) + self.decoder_kernels = [3, 4] + self.vit_size = 8 + else: + raise RuntimeError("Unknown vit img size") + # self.dino_vit = ViTFeat("/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", feat_dim=768, vit_arch='base', patch_size=8) + self.vit_feat_dim = self.dino_vit.feat_dim + self.dino_vit.requires_grad_(False) + + if encode_vit: + self.post_vit = nn.Sequential( + View((-1, self.vit_feat_dim, self.vit_size, self.vit_size)), + 
Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[2], + channel_step=384, + flatten_output=False, + in_channels=self.vit_feat_dim + ) + ) + self.encoder = nn.Sequential( + self.dino_vit, + self.post_vit + ) + else: + self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + kernel_sizes=[4, 4, 4, 4], + channel_step=48) + + self.n_dim = 256 + + if decode_vit: + # self.dino_predictor = SpatialBroadcastDecoder(self.n_dim, + # norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + # out_image=(14, 14), + # kernel_sizes = [5, 5, 5], + # channel_step=self.vit_feat_dim, + # output_channels=self.vit_feat_dim+1, + # return_dist=False) + self.dino_predictor = Decoder(self.n_dim, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + conv_kernel_sizes=[3], + channel_step=self.vit_feat_dim, + kernel_sizes=self.decoder_kernels, + output_channels=self.vit_feat_dim+1, + return_dist=False) + + self.slots_num = slots_num + self.mask_combination = mask_combination + self.state_feature_num = (self.state_size//self.n_dim) + self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (1, self.state_feature_num), channel_last=True) + # TODO: slots will assume permutation-invariance + self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, True) + self.state_reshuffle = nn.Sequential(nn.Linear(self.state_size, self.state_feature_num*self.n_dim), + nn.ReLU(inplace=True), + nn.Linear(self.state_feature_num*self.n_dim, self.state_feature_num*self.n_dim)) + + # self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), + # nn.ReLU(inplace=True), + # nn.Linear(self.n_dim, self.n_dim)) + + self.image_predictor = Decoder(self.n_dim, + output_channels=4, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + return_dist=False) + + self.reward_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('mse')) + self.discount_predictor = fc_nn_generator(self.state_size, + 1, + hidden_size=400, + num_layers=5, + intermediate_activation=nn.ELU, + layer_norm=layer_norm, + final_activation=DistLayer('binary')) + self.reward_normalizer = Normalizer(momentum=1.00, scale=1.0, eps=1e-8) + + def slot_mask(self, masks: torch.Tensor) -> torch.Tensor: + match self.mask_combination: + case 'soft': + img_mask = F.softmax(masks, dim=-4) + case 'hard': + probs = F.softmax(masks - masks.logsumexp(dim=1,keepdim=True), dim=1) + img_mask = F.one_hot(masks.argmax(dim=1), num_classes=masks.shape[1]).permute(0, 4, 1, 2, 3) + (probs - probs.detach()) + case 'qmix': + raise NotImplementedError + case _: + raise NotImplementedError + return img_mask + + def precalc_data(self, obs: torch.Tensor) -> dict[str, torch.Tensor]: + if not self.decode_vit: + return {} + if not self.encode_vit: + ToTensor = tv.transforms.Compose([tv.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)), + tv.transforms.Resize(self.vit_img_size, antialias=True)]) + obs = ToTensor(obs + 0.5) + with torch.no_grad(): + d_features = self.dino_vit(obs).cpu() + return {'d_features': d_features} + + def get_initial_state(self, batch_size: int = 1, seq_size: int = 1): + device = next(self.parameters()).device + return State(torch.zeros(seq_size, batch_size, self.rssm_dim, device=device), + torch.zeros(seq_size, batch_size, self.latent_classes, self.latent_dim, device=device), + torch.zeros(seq_size, batch_size, self.latent_classes * self.latent_dim, 
device=device)) + + def predict_next(self, prev_state: State, action): + prior, _ = self.recurrent_model.predict_next(prev_state, action) + + # FIXME: rewrite to utilize batch processing + reward = self.reward_predictor(prior.combined).mode + if self.predict_discount: + discount_factors = self.discount_predictor(prior.combined).mode + else: + discount_factors = torch.ones_like(reward) + return prior, reward, discount_factors + + def get_latent(self, obs: torch.Tensor, action, state: t.Optional[State]) -> State: + if isinstance(state, tuple): + state = state[0] + if state is None: + state = self.get_initial_state() + embed = self.encoder(obs.unsqueeze(0)) + _, posterior, _ = self.recurrent_model.forward(state, embed.unsqueeze(0), + action) + return posterior, None + + def calculate_loss(self, obs: torch.Tensor, a: torch.Tensor, r: torch.Tensor, + discount: torch.Tensor, first: torch.Tensor, additional: dict[str, torch.Tensor]): + self.recurrent_model.on_train_step() + b, _, h, w = obs.shape # s <- BxHxWx3 + + if self.encode_vit: + embed = self.post_vit(additional['d_features']) + else: + embed = self.encoder(obs) + embed_c = embed.reshape(b // self.cluster_size, self.cluster_size, -1) + + a_c = a.reshape(-1, self.cluster_size, self.actions_num) + r_c = r.reshape(-1, self.cluster_size, 1) + d_c = discount.reshape(-1, self.cluster_size, 1) + first_c = first.reshape(-1, self.cluster_size, 1) + + losses = {} + metrics = {} + + def KL(dist1, dist2, free_nat = True): + KL_ = torch.distributions.kl_divergence + one = self.kl_free_nats * torch.ones(1, device=next(self.parameters()).device) + # TODO: kl_free_avg is used always + if free_nat: + kl_lhs = torch.maximum(KL_(Dist(dist2.detach()), Dist(dist1)).mean(), one) + kl_rhs = torch.maximum(KL_(Dist(dist2), Dist(dist1.detach())).mean(), one) + else: + kl_lhs = KL_(Dist(dist2.detach()), Dist(dist1)).mean() + kl_rhs = KL_(Dist(dist2), Dist(dist1.detach())).mean() + return ((self.alpha * kl_lhs + (1 - self.alpha) * kl_rhs)) + + priors = [] + posteriors = [] + + if self.decode_vit: + d_features = additional['d_features'] + + prev_state = self.get_initial_state(b // self.cluster_size) + for t in range(self.cluster_size): + # s_t <- 1xB^xHxWx3 + embed_t, a_t, first_t = embed_c[:, t].unsqueeze(0), a_c[:, t].unsqueeze(0), first_c[:, t].unsqueeze(0) + a_t = a_t * (1 - first_t) + + prior, posterior, diff = self.recurrent_model.forward(prev_state, embed_t, a_t) + prev_state = posterior + + priors.append(prior) + posteriors.append(posterior) + + # losses['loss_determ_recons'] += diff + + posterior = State.stack(posteriors) + prior = State.stack(priors) + + state = self.state_reshuffle(posterior.combined.transpose(0, 1)) + state = state.reshape(*state.shape[:-1], self.state_feature_num, self.n_dim) + state_pos_embedded = self.positional_augmenter_inp(state.unsqueeze(-3)).squeeze(-3) + + state_slots = self.slot_attention(state_pos_embedded.flatten(0, 1), None) + + r_pred = self.reward_predictor(posterior.combined.transpose(0, 1)) + f_pred = self.discount_predictor(posterior.combined.transpose(0, 1)) + + losses['loss_reconstruction_img'] = torch.Tensor([0]).to(obs.device) + + if not self.decode_vit: + # x_r = self.image_predictor(posterior.combined.transpose(0, 1).flatten(0, 1)) + # losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() + decoded_imgs, masks = self.image_predictor(state_slots.flatten(0, 1)).reshape(b, -1, 4, h, w).split([3, 1], dim=-3) + img_mask = self.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + + x_r = 
td.Independent(td.Normal(torch.sum(decoded_imgs, dim=-4), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean(dim=0) + losses['loss_reconstruction'] = img_rec + else: + if self.vit_l2_ratio == 1.0: + decoded_imgs_detached, masks = self.image_predictor(state_slots.flatten(0, 1).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=-3) + img_mask = self.slot_mask(masks) + decoded_imgs_detached = decoded_imgs_detached * img_mask + + x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=-4), 1.0), 3) + + img_rec = torch.tensor(0, device=obs.device) + img_rec_detached = -x_r_detached.log_prob(obs).float().mean() + + losses['loss_reconstruction_img'] = img_rec_detached + else: + decoded_imgs, masks = self.image_predictor(state_slots.flatten(0, 1)).reshape(b, -1, 4, h, w).split([3, 1], dim=-3) + img_mask = self.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=-4), 1.0), 3) + img_rec = -x_r.log_prob(obs).float() + + decoded_feats, masks = self.dino_predictor(state_slots.flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, self.vit_size, self.vit_size).split([self.vit_feat_dim, 1], dim=-3) + feat_mask = self.slot_mask(masks) + + d_obs = d_features.reshape(b, self.vit_feat_dim, self.vit_size, self.vit_size) + + decoded_feats = decoded_feats * feat_mask + + d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=-4), 1.0), 3) + d_rec = -d_pred.log_prob(d_obs).float().mean() + + d_rec = d_rec / torch.prod(torch.tensor(d_obs.shape[-3:])) * torch.prod(torch.tensor(obs.shape[-3:])) + + losses['loss_reconstruction'] = (self.vit_l2_ratio * d_rec + (1-self.vit_l2_ratio) * img_rec) + metrics['loss_l2_rec'] = img_rec + metrics['loss_dino_rec'] = d_rec + + prior_logits = prior.stoch_logits + posterior_logits = posterior.stoch_logits + losses['loss_reward_pred'] = -r_pred.log_prob(r_c).float().mean() + losses['loss_discount_pred'] = -f_pred.log_prob(d_c).float().mean() + losses['loss_kl_reg'] = KL(prior_logits, posterior_logits) + + metrics['reward_mean'] = r.mean() + metrics['reward_std'] = r.std() + metrics['reward_sae'] = (torch.abs(r_pred.mode - r_c)).mean() + metrics['prior_entropy'] = Dist(prior_logits).entropy().mean() + metrics['posterior_entropy'] = Dist(posterior_logits).entropy().mean() + + losses['loss_wm'] = (losses['loss_reconstruction'] + losses['loss_reward_pred'] + + self.kl_beta * losses['loss_kl_reg'] + self.discount_scale*losses['loss_discount_pred']) + + return losses, posterior, metrics diff --git a/rl_sandbox/config/config_postslot.yaml b/rl_sandbox/config/config_postslot.yaml new file mode 100644 index 0000000..d62809b --- /dev/null +++ b/rl_sandbox/config/config_postslot.yaml @@ -0,0 +1,71 @@ +defaults: + - agent: dreamer_v2 + - env: dm_finger_spin + - training: dm + - logger: wandb + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda:0 + +agent: + world_model: + _target_: rl_sandbox.agents.dreamer.world_model_post_slot.WorldModel + rssm_dim: 256 + slots_num: 5 + slots_iter_num: 3 + + encode_vit: false + decode_vit: true + mask_combination: soft + vit_l2_ratio: 1.0 + + vit_img_size: 224 + kl_loss_scale: 1.0 + kl_loss_balancing: 0.8 + kl_free_nats: 1.0 + + wm_optim: + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 + +logger: + message: Post-wm slot attention, with dino, finger spin, n_dim=256, state_feature=5, slots=5 + log_grads: false + +training: + f16_precision: false + checkpoint_path: null + steps: 
1e6 + val_logs_every: 2e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.PostSlottedDreamerMetricsEvaluator + _partial_: true + #- _target_: rl_sandbox.crafter_metrics.CrafterMetricsEvaluator + # _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + n_jobs: 1 + # n_jobs: 8 + #sweeper: + # params: + # agent.world_model.kl_loss_scale: 1e-3,5e-3,1e-2,5e-2,5,10,15,20,25,50,75,100,250 + # agent.world_model.full_qk_from: 1,2e4 + # agent.world_model.symmetric_qk: true,false + # agent.world_model.attention_block_num: 1,3 diff --git a/rl_sandbox/config/env/dm_finger_spin.yaml b/rl_sandbox/config/env/dm_finger_spin.yaml new file mode 100644 index 0000000..a4b8f9f --- /dev/null +++ b/rl_sandbox/config/env/dm_finger_spin.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: finger +task_name: spin +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 0517ee0..981b1df 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -377,3 +377,67 @@ def viz_log(self, rollout, logger, epoch_num): logger.add_scalar('val/img_reward_err', rewards_err.item(), epoch_num) logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) + +class PostSlottedDreamerMetricsEvaluator(SlottedDreamerMetricsEvaluator): + def on_step(self, logger): + self.stored_steps += 1 + + if self.agent.is_discrete: + self._action_probs += self._action_probs + self._latent_probs += self.agent._state[0].stoch_dist.base_dist.probs.squeeze() + + def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): + # obs = torch.from_numpy(obs.copy()).to(self.agent.device) + # obs = self.agent.preprocess_obs(obs) + # actions = self.agent.from_np(actions) + if self.agent.is_discrete: + actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() + video = [] + slots_video = [] + rews = [] + + state = None + prev_slots = None + for idx, (o, a) in enumerate(list(zip(obs, actions))): + if idx > update_num: + break + state, prev_slots = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), (state, prev_slots)) + # video_r = self.agent.world_model.image_predictor(state.combined_slots).mode + + wm_state = self.agent.world_model.state_reshuffle(state.combined) + wm_state = wm_state.reshape(*wm_state.shape[:-1], self.agent.world_model.state_feature_num, self.agent.world_model.n_dim) + wm_state_pos_embedded = self.agent.world_model.positional_augmenter_inp(wm_state.unsqueeze(-3)).squeeze(-3) + wm_state_slots = self.agent.world_model.slot_attention(wm_state_pos_embedded.flatten(0, 1), None) + + decoded_imgs, masks = self.agent.world_model.image_predictor(wm_state_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) + # TODO: try the scaling of softmax as in attention + img_mask = self.agent.world_model.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1) + + rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) + video.append(self.agent.unprocess_obs(video_r)) + slots_video.append(self.agent.unprocess_obs(decoded_imgs)) + + rews = torch.Tensor(rews).to(obs.device) + + if update_num < len(obs): + states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), 
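+            # past update_num the rollout is imagined from the last inferred posterior using the recorded actions, so only the world model (not the real env) produces the remaining frames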
horizon=self.agent.imagination_horizon - 1 - update_num) + rews = torch.cat([rews, rews_2[1:].squeeze()]) + + wm_state = self.agent.world_model.state_reshuffle(states.combined[1:]) + wm_state = wm_state.reshape(*wm_state.shape[:-1], self.agent.world_model.state_feature_num, self.agent.world_model.n_dim) + wm_state_pos_embedded = self.agent.world_model.positional_augmenter_inp(wm_state.unsqueeze(-3)).squeeze(-3) + wm_state_slots = self.agent.world_model.slot_attention(wm_state_pos_embedded.flatten(0, 1), None) + + # video_r = self.agent.world_model.image_predictor(states.combined_slots[1:]).mode + decoded_imgs, masks = self.agent.world_model.image_predictor(wm_state_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, 4, 64, 64).split([3, 1], dim=2) + img_mask = self.agent.world_model.slot_mask(masks) + decoded_imgs = decoded_imgs * img_mask + video_r = torch.sum(decoded_imgs, dim=1) + + video.append(self.agent.unprocess_obs(video_r)) + slots_video.append(self.agent.unprocess_obs(decoded_imgs)) + + return torch.cat(video), rews, torch.cat(slots_video) + From 83690182c9292e762082ca3c0781c202676a1325 Mon Sep 17 00:00:00 2001 From: Roman Milishchuk Date: Sun, 13 Aug 2023 19:48:45 +0100 Subject: [PATCH 102/106] Added per slot rec loss --- .../agents/dreamer/world_model_post_slot.py | 66 +++++++++++++++---- rl_sandbox/config/config_postslot.yaml | 7 +- rl_sandbox/metrics.py | 8 ++- 3 files changed, 64 insertions(+), 17 deletions(-) diff --git a/rl_sandbox/agents/dreamer/world_model_post_slot.py b/rl_sandbox/agents/dreamer/world_model_post_slot.py index 4e0c41a..14c7734 100644 --- a/rl_sandbox/agents/dreamer/world_model_post_slot.py +++ b/rl_sandbox/agents/dreamer/world_model_post_slot.py @@ -19,11 +19,13 @@ class WorldModel(nn.Module): def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, - vit_l2_ratio: float, vit_img_size: int, slots_num: int, slots_iter_num: int, mask_combination: str): + vit_l2_ratio: float, vit_img_size: int, slots_num: int, slots_iter_num: int, + mask_combination: str, use_reshuffle: bool, per_slot_rec_loss: bool): super().__init__() self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) self.discount_scale = discount_loss_scale self.kl_beta = kl_loss_scale + self.per_slot_rec_loss = per_slot_rec_loss self.rssm_dim = rssm_dim self.latent_dim = latent_dim @@ -106,9 +108,11 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (1, self.state_feature_num), channel_last=True) # TODO: slots will assume permutation-invariance self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, True) - self.state_reshuffle = nn.Sequential(nn.Linear(self.state_size, self.state_feature_num*self.n_dim), - nn.ReLU(inplace=True), - nn.Linear(self.state_feature_num*self.n_dim, self.state_feature_num*self.n_dim)) + self.use_reshuffle = use_reshuffle + if self.use_reshuffle: + self.state_reshuffle = nn.Sequential(nn.Linear(self.state_size, self.state_feature_num*self.n_dim), + nn.ReLU(inplace=True), + nn.Linear(self.state_feature_num*self.n_dim, self.state_feature_num*self.n_dim)) # self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim), # nn.ReLU(inplace=True), @@ -241,7 +245,11 @@ def KL(dist1, dist2, free_nat = True): posterior = State.stack(posteriors) prior = 
State.stack(priors) - state = self.state_reshuffle(posterior.combined.transpose(0, 1)) + if self.use_reshuffle: + state = self.state_reshuffle(posterior.combined.transpose(0, 1)) + else: + state = posterior.combined.transpose(0, 1) + assert state.shape[-1] % self.n_dim == 0 and self.rssm_dim % self.n_dim == 0 state = state.reshape(*state.shape[:-1], self.state_feature_num, self.n_dim) state_pos_embedded = self.positional_augmenter_inp(state.unsqueeze(-3)).squeeze(-3) @@ -259,8 +267,15 @@ def KL(dist1, dist2, free_nat = True): img_mask = self.slot_mask(masks) decoded_imgs = decoded_imgs * img_mask - x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=-4), 1.0), 3) - img_rec = -x_r.log_prob(obs).float().mean(dim=0) + if self.per_slot_rec_loss: + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # magic constant that describes the difference between log_prob and mse losses + img_rec = l2_loss.mean() * normalizing_factor * 8 + img_rec = img_rec.mean() + else: + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=-4), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean(dim=0) losses['loss_reconstruction'] = img_rec else: if self.vit_l2_ratio == 1.0: @@ -268,10 +283,18 @@ def KL(dist1, dist2, free_nat = True): img_mask = self.slot_mask(masks) decoded_imgs_detached = decoded_imgs_detached * img_mask - x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=-4), 1.0), 3) - img_rec = torch.tensor(0, device=obs.device) - img_rec_detached = -x_r_detached.log_prob(obs).float().mean() + + if self.per_slot_rec_loss: + l2_loss = (img_mask * ((decoded_imgs_detached - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # magic constant that describes the difference between log_prob and mse losses + img_rec_detached = l2_loss.mean() * normalizing_factor * 8 + img_rec_detached = img_rec_detached.mean() + else: + x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=-4), 1.0), 3) + + img_rec_detached = -x_r_detached.log_prob(obs).float().mean() losses['loss_reconstruction_img'] = img_rec_detached else: @@ -279,8 +302,15 @@ def KL(dist1, dist2, free_nat = True): img_mask = self.slot_mask(masks) decoded_imgs = decoded_imgs * img_mask - x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=-4), 1.0), 3) - img_rec = -x_r.log_prob(obs).float() + if self.per_slot_rec_loss: + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # magic constant that describes the difference between log_prob and mse losses + img_rec = l2_loss.mean() * normalizing_factor * 8 + img_rec = img_rec.mean() + else: + x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=-4), 1.0), 3) + img_rec = -x_r.log_prob(obs).float().mean() decoded_feats, masks = self.dino_predictor(state_slots.flatten(0, 1)).reshape(b, -1, self.vit_feat_dim+1, self.vit_size, self.vit_size).split([self.vit_feat_dim, 1], dim=-3) feat_mask = self.slot_mask(masks) @@ -289,8 +319,16 @@ def KL(dist1, dist2, free_nat = True): decoded_feats = decoded_feats * feat_mask - d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=-4), 1.0), 3) - d_rec = -d_pred.log_prob(d_obs).float().mean() + if self.per_slot_rec_loss: + l2_loss = 
(feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(d_obs.shape)[-3:]) / feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) + # l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=2).max(dim=2).values.max(dim=2).values * (64*64*3) + # # magic constant that describes the difference between log_prob and mse losses + d_rec = l2_loss.mean() * normalizing_factor * 4 + d_rec = d_rec.mean() + else: + d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=-4), 1.0), 3) + d_rec = -d_pred.log_prob(d_obs).float().mean() d_rec = d_rec / torch.prod(torch.tensor(d_obs.shape[-3:])) * torch.prod(torch.tensor(obs.shape[-3:])) diff --git a/rl_sandbox/config/config_postslot.yaml b/rl_sandbox/config/config_postslot.yaml index d62809b..b7a0a02 100644 --- a/rl_sandbox/config/config_postslot.yaml +++ b/rl_sandbox/config/config_postslot.yaml @@ -17,7 +17,7 @@ agent: slots_iter_num: 3 encode_vit: false - decode_vit: true + decode_vit: false mask_combination: soft vit_l2_ratio: 1.0 @@ -26,6 +26,9 @@ agent: kl_loss_balancing: 0.8 kl_free_nats: 1.0 + use_reshuffle: true + per_slot_rec_loss: true + wm_optim: lr_scheduler: - _target_: rl_sandbox.utils.optimizer.WarmupScheduler @@ -33,7 +36,7 @@ agent: warmup_steps: 1e3 logger: - message: Post-wm slot attention, with dino, finger spin, n_dim=256, state_feature=5, slots=5 + message: Post-wm slot attention, per slot rec, finger spin, n_dim=256, state_feature=5, slots=5 log_grads: false training: diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 981b1df..78787f6 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -147,7 +147,6 @@ def viz_log(self, rollout, logger, epoch_num): logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) - class SlottedDreamerMetricsEvaluator(DreamerMetricsEvaluator): def on_step(self, logger): self.stored_steps += 1 @@ -262,6 +261,13 @@ def viz_log(self, rollout, logger, epoch_num): logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) class SlottedDinoDreamerMetricsEvaluator(SlottedDreamerMetricsEvaluator): + def on_step(self, logger): + self.stored_steps += 1 + + if self.agent.is_discrete: + self._action_probs += self._action_probs + self._latent_probs += self.agent._state[0].stoch_dist.base_dist.probs.squeeze() + def _generate_video(self, obs: list[Observation], actions: list[Action], d_feats: list[torch.Tensor], update_num: int): # obs = torch.from_numpy(obs.copy()).to(self.agent.device) # obs = self.agent.preprocess_obs(obs) From 38697554c4aed1ce7cfb43c23416ef7c956585c8 Mon Sep 17 00:00:00 2001 From: Midren Date: Sun, 13 Aug 2023 20:46:35 +0100 Subject: [PATCH 103/106] Fixed per slot rec loss --- .../agents/dreamer/world_model_post_slot.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/rl_sandbox/agents/dreamer/world_model_post_slot.py b/rl_sandbox/agents/dreamer/world_model_post_slot.py index 14c7734..a8808fd 100644 --- a/rl_sandbox/agents/dreamer/world_model_post_slot.py +++ b/rl_sandbox/agents/dreamer/world_model_post_slot.py @@ -268,11 +268,11 @@ def KL(dist1, dist2, free_nat = True): decoded_imgs = decoded_imgs * img_mask if self.per_slot_rec_loss: - l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + l2_loss = (((decoded_imgs - img_mask * obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) # magic constant that describes the 
difference between log_prob and mse losses - img_rec = l2_loss.mean() * normalizing_factor * 8 - img_rec = img_rec.mean() + img_rec = l2_loss * normalizing_factor * 8 + img_rec = img_rec.sum(dim=1).mean() else: x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=-4), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean(dim=0) @@ -286,11 +286,11 @@ def KL(dist1, dist2, free_nat = True): img_rec = torch.tensor(0, device=obs.device) if self.per_slot_rec_loss: - l2_loss = (img_mask * ((decoded_imgs_detached - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + l2_loss = (((decoded_imgs_detached - img_mask * obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) # magic constant that describes the difference between log_prob and mse losses - img_rec_detached = l2_loss.mean() * normalizing_factor * 8 - img_rec_detached = img_rec_detached.mean() + img_rec_detached = l2_loss * normalizing_factor * 8 + img_rec_detached = img_rec_detached.sum(dim=1).mean() else: x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=-4), 1.0), 3) @@ -303,11 +303,11 @@ def KL(dist1, dist2, free_nat = True): decoded_imgs = decoded_imgs * img_mask if self.per_slot_rec_loss: - l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + l2_loss = ((decoded_imgs - img_mask * obs.unsqueeze(1))**2).sum(dim=[2, 3, 4]) normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) # magic constant that describes the difference between log_prob and mse losses - img_rec = l2_loss.mean() * normalizing_factor * 8 - img_rec = img_rec.mean() + img_rec = l2_loss * normalizing_factor * 8 + img_rec = img_rec.sum(dim=1).mean() else: x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=-4), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean() @@ -320,12 +320,12 @@ def KL(dist1, dist2, free_nat = True): decoded_feats = decoded_feats * feat_mask if self.per_slot_rec_loss: - l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + l2_loss = ((decoded_feats - feat_mask*d_obs.unsqueeze(1))**2).sum(dim=[2, 3, 4]) normalizing_factor = torch.prod(torch.tensor(d_obs.shape)[-3:]) / feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) # l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=2).max(dim=2).values.max(dim=2).values * (64*64*3) # # magic constant that describes the difference between log_prob and mse losses - d_rec = l2_loss.mean() * normalizing_factor * 4 - d_rec = d_rec.mean() + d_rec = l2_loss * normalizing_factor * 4 + d_rec = d_rec.sum(dim=1).mean() else: d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=-4), 1.0), 3) d_rec = -d_pred.log_prob(d_obs).float().mean() From ac5998d2ca71d7bfe44bba42d4e77d62b5442e0a Mon Sep 17 00:00:00 2001 From: Midren Date: Sun, 13 Aug 2023 22:46:17 +0100 Subject: [PATCH 104/106] Totally fixed per slot rec for post slot --- .../agents/dreamer/world_model_post_slot.py | 49 ++++++++++--------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/rl_sandbox/agents/dreamer/world_model_post_slot.py b/rl_sandbox/agents/dreamer/world_model_post_slot.py index a8808fd..375f277 100644 --- a/rl_sandbox/agents/dreamer/world_model_post_slot.py +++ b/rl_sandbox/agents/dreamer/world_model_post_slot.py @@ -1,6 +1,7 @@ import typing as t import torch +import math import torch.distributions as td import torchvision as tv from torch import nn @@ -265,15 +266,16 @@ def KL(dist1, dist2, 
free_nat = True): # losses['loss_reconstruction'] = -x_r.log_prob(obs).float().mean() decoded_imgs, masks = self.image_predictor(state_slots.flatten(0, 1)).reshape(b, -1, 4, h, w).split([3, 1], dim=-3) img_mask = self.slot_mask(masks) - decoded_imgs = decoded_imgs * img_mask if self.per_slot_rec_loss: - l2_loss = (((decoded_imgs - img_mask * obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) - normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) - # magic constant that describes the difference between log_prob and mse losses - img_rec = l2_loss * normalizing_factor * 8 - img_rec = img_rec.sum(dim=1).mean() + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) / 3 + img_rec = l2_loss * normalizing_factor + torch.prod(torch.tensor(obs.shape)[-3:]) * math.log((2*math.pi)**(1/2)) + img_rec = img_rec.mean() + + decoded_imgs = decoded_imgs * img_mask else: + decoded_imgs = decoded_imgs * img_mask x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=-4), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean(dim=0) losses['loss_reconstruction'] = img_rec @@ -281,17 +283,17 @@ def KL(dist1, dist2, free_nat = True): if self.vit_l2_ratio == 1.0: decoded_imgs_detached, masks = self.image_predictor(state_slots.flatten(0, 1).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=-3) img_mask = self.slot_mask(masks) - decoded_imgs_detached = decoded_imgs_detached * img_mask img_rec = torch.tensor(0, device=obs.device) if self.per_slot_rec_loss: - l2_loss = (((decoded_imgs_detached - img_mask * obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) - normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) - # magic constant that describes the difference between log_prob and mse losses - img_rec_detached = l2_loss * normalizing_factor * 8 - img_rec_detached = img_rec_detached.sum(dim=1).mean() + l2_loss = (img_mask * ((decoded_imgs_detached - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) / 3 + img_rec_detached = l2_loss * normalizing_factor + torch.prod(torch.tensor(obs.shape)[-3:]) * math.log((2*math.pi)**(1/2)) + img_rec_detached = img_rec_detached.mean() + decoded_imgs_detached = decoded_imgs_detached * img_mask else: + decoded_imgs_detached = decoded_imgs_detached * img_mask x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=-4), 1.0), 3) img_rec_detached = -x_r_detached.log_prob(obs).float().mean() @@ -300,15 +302,16 @@ def KL(dist1, dist2, free_nat = True): else: decoded_imgs, masks = self.image_predictor(state_slots.flatten(0, 1)).reshape(b, -1, 4, h, w).split([3, 1], dim=-3) img_mask = self.slot_mask(masks) - decoded_imgs = decoded_imgs * img_mask if self.per_slot_rec_loss: - l2_loss = ((decoded_imgs - img_mask * obs.unsqueeze(1))**2).sum(dim=[2, 3, 4]) - normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) - # magic constant that describes the difference between log_prob and mse losses - img_rec = l2_loss * normalizing_factor * 8 - img_rec = img_rec.sum(dim=1).mean() + l2_loss = (img_mask * ((decoded_imgs - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) / 3 + img_rec = l2_loss * normalizing_factor + 
torch.prod(torch.tensor(obs.shape)[-3:]) * math.log((2*math.pi)**(1/2)) + img_rec = img_rec.mean() + + decoded_imgs = decoded_imgs * img_mask else: + decoded_imgs = decoded_imgs * img_mask x_r = td.Independent(td.Normal(torch.sum(decoded_imgs, dim=-4), 1.0), 3) img_rec = -x_r.log_prob(obs).float().mean() @@ -320,12 +323,10 @@ def KL(dist1, dist2, free_nat = True): decoded_feats = decoded_feats * feat_mask if self.per_slot_rec_loss: - l2_loss = ((decoded_feats - feat_mask*d_obs.unsqueeze(1))**2).sum(dim=[2, 3, 4]) - normalizing_factor = torch.prod(torch.tensor(d_obs.shape)[-3:]) / feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) - # l2_loss = (feat_mask*((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=2).max(dim=2).values.max(dim=2).values * (64*64*3) - # # magic constant that describes the difference between log_prob and mse losses - d_rec = l2_loss * normalizing_factor * 4 - d_rec = d_rec.sum(dim=1).mean() + l2_loss = (feat_mask * ((decoded_feats - d_obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + normalizing_factor = torch.prod(torch.tensor(d_obs.shape)[-3:]) / feat_mask.sum(dim=[2, 3, 4]).clamp(min=1) / 3 + d_rec = l2_loss * normalizing_factor + torch.prod(torch.tensor(d_obs.shape)[-3:]) * math.log((2*math.pi)**(1/2)) + d_rec = d_rec.mean() else: d_pred = td.Independent(td.Normal(torch.sum(decoded_feats, dim=-4), 1.0), 3) d_rec = -d_pred.log_prob(d_obs).float().mean() From 72449dc2c38636cfc1ea00e71b6527870609ce9d Mon Sep 17 00:00:00 2001 From: Midren Date: Fri, 18 Aug 2023 17:48:08 +0100 Subject: [PATCH 105/106] Add config option to tweak spatial decoder --- .../agents/dreamer/world_model_post_slot.py | 32 ++++++++++--------- .../dreamer/world_model_slots_attention.py | 29 +++++++++++------ .../agent/dreamer_v2_slotted_attention.yaml | 2 ++ rl_sandbox/config/config_postslot.yaml | 3 +- rl_sandbox/metrics.py | 7 ---- 5 files changed, 41 insertions(+), 32 deletions(-) diff --git a/rl_sandbox/agents/dreamer/world_model_post_slot.py b/rl_sandbox/agents/dreamer/world_model_post_slot.py index 375f277..4cf83fa 100644 --- a/rl_sandbox/agents/dreamer/world_model_post_slot.py +++ b/rl_sandbox/agents/dreamer/world_model_post_slot.py @@ -21,7 +21,7 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, actions_num, discount_loss_scale, kl_loss_scale, kl_loss_balancing, kl_free_nats, discrete_rssm, predict_discount, layer_norm: bool, encode_vit: bool, decode_vit: bool, vit_l2_ratio: float, vit_img_size: int, slots_num: int, slots_iter_num: int, - mask_combination: str, use_reshuffle: bool, per_slot_rec_loss: bool): + mask_combination: str, use_reshuffle: bool, per_slot_rec_loss: bool, spatial_decoder: bool): super().__init__() self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) self.discount_scale = discount_loss_scale @@ -88,20 +88,22 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, self.n_dim = 256 if decode_vit: - # self.dino_predictor = SpatialBroadcastDecoder(self.n_dim, - # norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - # out_image=(14, 14), - # kernel_sizes = [5, 5, 5], - # channel_step=self.vit_feat_dim, - # output_channels=self.vit_feat_dim+1, - # return_dist=False) - self.dino_predictor = Decoder(self.n_dim, - norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - conv_kernel_sizes=[3], - channel_step=self.vit_feat_dim, - kernel_sizes=self.decoder_kernels, - output_channels=self.vit_feat_dim+1, - return_dist=False) + if spatial_decoder: + self.dino_predictor = SpatialBroadcastDecoder(self.n_dim, + 
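+                    # broadcasts each slot vector over a fixed spatial grid and decodes it with small convolutions, as an alternative to the strided Decoder in the else branch below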
norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + out_image=(14, 14), + kernel_sizes = [5, 5, 5], + channel_step=self.vit_feat_dim, + output_channels=self.vit_feat_dim+1, + return_dist=False) + else: + self.dino_predictor = Decoder(self.n_dim, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + conv_kernel_sizes=[3], + channel_step=self.vit_feat_dim, + kernel_sizes=self.decoder_kernels, + output_channels=self.vit_feat_dim+1, + return_dist=False) self.slots_num = slots_num self.mask_combination = mask_combination diff --git a/rl_sandbox/agents/dreamer/world_model_slots_attention.py b/rl_sandbox/agents/dreamer/world_model_slots_attention.py index b5ccd02..ed97145 100644 --- a/rl_sandbox/agents/dreamer/world_model_slots_attention.py +++ b/rl_sandbox/agents/dreamer/world_model_slots_attention.py @@ -8,7 +8,7 @@ from rl_sandbox.agents.dreamer import Dist, Normalizer, View, get_position_encoding from rl_sandbox.agents.dreamer.rssm_slots_attention import RSSM, State -from rl_sandbox.agents.dreamer.vision import Decoder, Encoder +from rl_sandbox.agents.dreamer.vision import SpatialBroadcastDecoder, Decoder, Encoder from rl_sandbox.utils.dists import DistLayer from rl_sandbox.utils.fc_nn import fc_nn_generator from rl_sandbox.vision.dino import ViTFeat @@ -25,7 +25,8 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, symmetric_qk: bool = False, attention_block_num: int = 3, mask_combination: str = 'soft', - per_slot_rec_loss: bool = False): + per_slot_rec_loss: bool = False, + spatial_decoder: bool = False): super().__init__() self.use_prev_slots = use_prev_slots self.register_buffer('kl_free_nats', kl_free_nats * torch.ones(1)) @@ -107,13 +108,23 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, nn.Linear(self.n_dim, self.n_dim)) if decode_vit: - self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, - norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - conv_kernel_sizes=[3], - channel_step=2*self.vit_feat_dim, - kernel_sizes=self.decoder_kernels, - output_channels=self.vit_feat_dim+1, - return_dist=False) + if spatial_decoder: + self.dino_predictor = SpatialBroadcastDecoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + out_image=(14, 14), + kernel_sizes = [5, 5, 5], + channel_step=self.vit_feat_dim, + output_channels=self.vit_feat_dim+1, + return_dist=False) + else: + self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + conv_kernel_sizes=[3], + channel_step=self.vit_feat_dim, + kernel_sizes=self.decoder_kernels, + output_channels=self.vit_feat_dim+1, + return_dist=False) + self.image_predictor = Decoder( rssm_dim + latent_dim * latent_classes, norm_layer=nn.GroupNorm if layer_norm else nn.Identity, diff --git a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml index 0a5f68f..2830e33 100644 --- a/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml +++ b/rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml @@ -19,6 +19,8 @@ world_model: symmetric_qk: false attention_block_num: 3 + spatial_decoder: false + wm_optim: lr_scheduler: - _target_: rl_sandbox.utils.optimizer.WarmupScheduler diff --git a/rl_sandbox/config/config_postslot.yaml b/rl_sandbox/config/config_postslot.yaml index b7a0a02..285244c 100644 --- a/rl_sandbox/config/config_postslot.yaml +++ b/rl_sandbox/config/config_postslot.yaml @@ 
-27,7 +27,8 @@ agent: kl_free_nats: 1.0 use_reshuffle: true - per_slot_rec_loss: true + per_slot_rec_loss: false + spatial_decoder: false wm_optim: lr_scheduler: diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 78787f6..8d85d28 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -261,13 +261,6 @@ def viz_log(self, rollout, logger, epoch_num): logger.add_scalar(f'val/reward', real_rewards[0].sum(), epoch_num) class SlottedDinoDreamerMetricsEvaluator(SlottedDreamerMetricsEvaluator): - def on_step(self, logger): - self.stored_steps += 1 - - if self.agent.is_discrete: - self._action_probs += self._action_probs - self._latent_probs += self.agent._state[0].stoch_dist.base_dist.probs.squeeze() - def _generate_video(self, obs: list[Observation], actions: list[Action], d_feats: list[torch.Tensor], update_num: int): # obs = torch.from_numpy(obs.copy()).to(self.agent.device) # obs = self.agent.preprocess_obs(obs) From edad471a25a5081235e167566b364e078f4d9468 Mon Sep 17 00:00:00 2001 From: Midren Date: Sat, 2 Sep 2023 15:08:11 +0100 Subject: [PATCH 106/106] No image predictor in dino only, new envs --- .../agents/dreamer/world_model_post_slot.py | 36 ++++----- rl_sandbox/agents/dreamer_v2.py | 8 +- rl_sandbox/config/agent/dreamer_v2.yaml | 4 +- rl_sandbox/config/config.yaml | 59 +++++++------- rl_sandbox/config/config_postslot.yaml | 28 +++---- rl_sandbox/config/config_postslot_dino.yaml | 71 +++++++++++++++++ .../config/env/dm_finger_turn_easy.yaml | 8 ++ .../config/env/dm_finger_turn_hard.yaml | 8 ++ rl_sandbox/config/env/dm_hopper_hop.yaml | 8 ++ rl_sandbox/config/env/dm_quadruped.yaml | 2 +- rl_sandbox/config/env/dm_quadruped_walk.yaml | 8 ++ rl_sandbox/config/env/dm_reacher_hard.yaml | 8 ++ rl_sandbox/config/env/dm_walker.yaml | 2 +- rl_sandbox/config/env/dm_walker_stand.yaml | 8 ++ rl_sandbox/config/training/dm.yaml | 4 +- rl_sandbox/metrics.py | 76 +++++++++++++++++++ rl_sandbox/train.py | 5 +- 17 files changed, 270 insertions(+), 73 deletions(-) create mode 100644 rl_sandbox/config/config_postslot_dino.yaml create mode 100644 rl_sandbox/config/env/dm_finger_turn_easy.yaml create mode 100644 rl_sandbox/config/env/dm_finger_turn_hard.yaml create mode 100644 rl_sandbox/config/env/dm_hopper_hop.yaml create mode 100644 rl_sandbox/config/env/dm_quadruped_walk.yaml create mode 100644 rl_sandbox/config/env/dm_reacher_hard.yaml create mode 100644 rl_sandbox/config/env/dm_walker_stand.yaml diff --git a/rl_sandbox/agents/dreamer/world_model_post_slot.py b/rl_sandbox/agents/dreamer/world_model_post_slot.py index 4cf83fa..aad0e55 100644 --- a/rl_sandbox/agents/dreamer/world_model_post_slot.py +++ b/rl_sandbox/agents/dreamer/world_model_post_slot.py @@ -121,10 +121,11 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim, # nn.ReLU(inplace=True), # nn.Linear(self.n_dim, self.n_dim)) - self.image_predictor = Decoder(self.n_dim, - output_channels=4, - norm_layer=nn.GroupNorm if layer_norm else nn.Identity, - return_dist=False) + if not decode_vit: + self.image_predictor = Decoder(self.n_dim, + output_channels=4, + norm_layer=nn.GroupNorm if layer_norm else nn.Identity, + return_dist=False) self.reward_predictor = fc_nn_generator(self.state_size, 1, @@ -283,24 +284,25 @@ def KL(dist1, dist2, free_nat = True): losses['loss_reconstruction'] = img_rec else: if self.vit_l2_ratio == 1.0: - decoded_imgs_detached, masks = self.image_predictor(state_slots.flatten(0, 1).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=-3) - img_mask = self.slot_mask(masks) + 
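+                # with vit_l2_ratio == 1.0 only the DINO feature loss is trained, so the auxiliary image reconstruction (and its detached decoder pass) is dropped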
pass + # decoded_imgs_detached, masks = self.image_predictor(state_slots.flatten(0, 1).detach()).reshape(b, -1, 4, h, w).split([3, 1], dim=-3) + # img_mask = self.slot_mask(masks) img_rec = torch.tensor(0, device=obs.device) - if self.per_slot_rec_loss: - l2_loss = (img_mask * ((decoded_imgs_detached - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) - normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) / 3 - img_rec_detached = l2_loss * normalizing_factor + torch.prod(torch.tensor(obs.shape)[-3:]) * math.log((2*math.pi)**(1/2)) - img_rec_detached = img_rec_detached.mean() - decoded_imgs_detached = decoded_imgs_detached * img_mask - else: - decoded_imgs_detached = decoded_imgs_detached * img_mask - x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=-4), 1.0), 3) + # if self.per_slot_rec_loss: + # l2_loss = (img_mask * ((decoded_imgs_detached - obs.unsqueeze(1))**2)).sum(dim=[2, 3, 4]) + # normalizing_factor = torch.prod(torch.tensor(obs.shape)[-3:]) / img_mask.sum(dim=[2, 3, 4]).clamp(min=1) / 3 + # img_rec_detached = l2_loss * normalizing_factor + torch.prod(torch.tensor(obs.shape)[-3:]) * math.log((2*math.pi)**(1/2)) + # img_rec_detached = img_rec_detached.mean() + # decoded_imgs_detached = decoded_imgs_detached * img_mask + # else: + # decoded_imgs_detached = decoded_imgs_detached * img_mask + # x_r_detached = td.Independent(td.Normal(torch.sum(decoded_imgs_detached, dim=-4), 1.0), 3) - img_rec_detached = -x_r_detached.log_prob(obs).float().mean() + # img_rec_detached = -x_r_detached.log_prob(obs).float().mean() - losses['loss_reconstruction_img'] = img_rec_detached + # losses['loss_reconstruction_img'] = img_rec_detached else: decoded_imgs, masks = self.image_predictor(state_slots.flatten(0, 1)).reshape(b, -1, 4, h, w).split([3, 1], dim=-3) img_mask = self.slot_mask(masks) diff --git a/rl_sandbox/agents/dreamer_v2.py b/rl_sandbox/agents/dreamer_v2.py index 3c61120..cbecf7b 100644 --- a/rl_sandbox/agents/dreamer_v2.py +++ b/rl_sandbox/agents/dreamer_v2.py @@ -58,8 +58,8 @@ def __init__( self.critic: ImaginativeCritic = torch.compile(critic(latent_dim=self.world_model.state_size), mode='max-autotune').to(device_type) self.world_model_optimizer = wm_optim(model=self.world_model, scaler=self.is_f16) - if self.world_model.decode_vit and self.world_model.vit_l2_ratio == 1.0: - self.image_predictor_optimizer = wm_optim(model=self.world_model.image_predictor, scaler=self.is_f16) + # if self.world_model.decode_vit and self.world_model.vit_l2_ratio == 1.0: + # self.image_predictor_optimizer = wm_optim(model=self.world_model.image_predictor, scaler=self.is_f16) self.actor_optimizer = actor_optim(model=self.actor) self.critic_optimizer = critic_optim(model=self.critic) @@ -170,8 +170,8 @@ def train(self, rollout_chunks: RolloutChunks): # FIXME: wholely remove discrete RSSM # self.world_model.recurrent_model.discretizer_scheduler.step() - if self.world_model.decode_vit and self.world_model.vit_l2_ratio == 1.0: - self.image_predictor_optimizer.step(losses_wm['loss_reconstruction_img']) + # if self.world_model.decode_vit and self.world_model.vit_l2_ratio == 1.0: + # self.image_predictor_optimizer.step(losses_wm['loss_reconstruction_img']) metrics_wm |= self.world_model_optimizer.step(losses_wm['loss_wm']) diff --git a/rl_sandbox/config/agent/dreamer_v2.yaml b/rl_sandbox/config/agent/dreamer_v2.yaml index bd74ca0..5d91ce5 100644 --- a/rl_sandbox/config/agent/dreamer_v2.yaml +++ b/rl_sandbox/config/agent/dreamer_v2.yaml @@ -13,7 
+13,7 @@ world_model: latent_classes: 32 rssm_dim: 200 discount_loss_scale: 1.0 - kl_loss_scale: 1 + kl_loss_scale: 2 kl_loss_balancing: 0.8 kl_free_nats: 1.00 discrete_rssm: false @@ -30,7 +30,7 @@ actor: # mixing of reinforce and maximizing value func # for dm_control it is zero in Dreamer (Atari 1) reinforce_fraction: null - entropy_scale: 1e-4 + entropy_scale: 1e-5 layer_norm: ${..layer_norm} critic: diff --git a/rl_sandbox/config/config.yaml b/rl_sandbox/config/config.yaml index 0bafcc2..c777d19 100644 --- a/rl_sandbox/config/config.yaml +++ b/rl_sandbox/config/config.yaml @@ -1,31 +1,35 @@ defaults: - - agent: dreamer_v2_crafter - - env: crafter - - training: crafter + - agent: dreamer_v2 + - env: dm_cartpole + - training: dm - logger: wandb - _self_ - override hydra/launcher: joblib -seed: 43 -device_type: cuda:1 +seed: 42 +device_type: cuda agent: world_model: - decode_vit: true - vit_img_size: 224 - vit_l2_ratio: 1.0 - kl_loss_scale: 3.0 - kl_loss_balancing: 0.8 - kl_free_nats: 1.0 + _target_: rl_sandbox.agents.dreamer.world_model.WorldModel + rssm_dim: 200 - actor_optim: - lr: 2e-4 + encode_vit: false + decode_vit: false + #vit_l2_ratio: 1.0 - critic_optim: - lr: 2e-4 + #kl_loss_scale: 2.0 + #kl_loss_balancing: 0.8 + #kl_free_nats: 1.0 + + #wm_optim: + # lr_scheduler: + # - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + # _partial_: true + # warmup_steps: 1e3 logger: - message: Dreamer with only dino 0.8/3, fp16 + message: Default dreamer fp16 log_grads: false training: @@ -43,21 +47,20 @@ validation: _partial_: true - _target_: rl_sandbox.metrics.DreamerMetricsEvaluator _partial_: true - - _target_: rl_sandbox.crafter_metrics.CrafterMetricsEvaluator - _partial_: true + #- _target_: rl_sandbox.metrics.PostSlottedDreamerMetricsEvaluator + # _partial_: true + #- _target_: rl_sandbox.crafter_metrics.CrafterMetricsEvaluator + # _partial_: true debug: profiler: false hydra: - #mode: MULTIRUN - mode: RUN + mode: MULTIRUN + #mode: RUN launcher: - n_jobs: 1 - #sweeper: - # params: - # agent.world_model._target_: rl_sandbox.agents.dreamer.world_model_slots_combined.WorldModel,rl_sandbox.agents.dreamer.world_model_slots_attention.WorldModel - # agent.world_model.vit_l2_ratio: 0.1,0.5 - # agent.world_model.kl_loss_scale: 1e1,1e2,1e3,1e4 - # agent.world_model.vit_l2_ratio: 0.1,0.9 - + n_jobs: 3 + sweeper: + params: + seed: 17,42,45 + env: dm_finger_spin,dm_finger_turn_hard diff --git a/rl_sandbox/config/config_postslot.yaml b/rl_sandbox/config/config_postslot.yaml index 285244c..85b4b9f 100644 --- a/rl_sandbox/config/config_postslot.yaml +++ b/rl_sandbox/config/config_postslot.yaml @@ -1,13 +1,13 @@ defaults: - agent: dreamer_v2 - - env: dm_finger_spin + - env: dm_acrobot - training: dm - logger: wandb - _self_ - override hydra/launcher: joblib seed: 42 -device_type: cuda:0 +device_type: cuda agent: world_model: @@ -37,11 +37,11 @@ agent: warmup_steps: 1e3 logger: - message: Post-wm slot attention, per slot rec, finger spin, n_dim=256, state_feature=5, slots=5 + message: Post-wm slot attention, n_dim=256 log_grads: false training: - f16_precision: false + f16_precision: true checkpoint_path: null steps: 1e6 val_logs_every: 2e4 @@ -55,21 +55,17 @@ validation: _partial_: true - _target_: rl_sandbox.metrics.PostSlottedDreamerMetricsEvaluator _partial_: true - #- _target_: rl_sandbox.crafter_metrics.CrafterMetricsEvaluator - # _partial_: true debug: profiler: false hydra: - #mode: MULTIRUN - mode: RUN + mode: MULTIRUN + #mode: RUN launcher: - n_jobs: 1 - # n_jobs: 8 - #sweeper: - # 
params: - # agent.world_model.kl_loss_scale: 1e-3,5e-3,1e-2,5e-2,5,10,15,20,25,50,75,100,250 - # agent.world_model.full_qk_from: 1,2e4 - # agent.world_model.symmetric_qk: true,false - # agent.world_model.attention_block_num: 1,3 + n_jobs: 3 + sweeper: + params: + seed: 17,42,45 + env: dm_finger_spin,dm_finger_turn_hard + diff --git a/rl_sandbox/config/config_postslot_dino.yaml b/rl_sandbox/config/config_postslot_dino.yaml new file mode 100644 index 0000000..94e9a11 --- /dev/null +++ b/rl_sandbox/config/config_postslot_dino.yaml @@ -0,0 +1,71 @@ +defaults: + - agent: dreamer_v2 + - env: dm_acrobot + - training: dm + - logger: wandb + - _self_ + - override hydra/launcher: joblib + +seed: 42 +device_type: cuda + +agent: + world_model: + _target_: rl_sandbox.agents.dreamer.world_model_post_slot.WorldModel + rssm_dim: 256 + slots_num: 5 + slots_iter_num: 3 + + encode_vit: false + decode_vit: true + mask_combination: soft + vit_l2_ratio: 1.0 + + vit_img_size: 224 + kl_loss_scale: 1.0 + kl_loss_balancing: 0.8 + kl_free_nats: 1.0 + + use_reshuffle: true + per_slot_rec_loss: false + spatial_decoder: false + + wm_optim: + lr_scheduler: + - _target_: rl_sandbox.utils.optimizer.WarmupScheduler + _partial_: true + warmup_steps: 1e3 + +logger: + message: Post-wm dino slot attention, n_dim=256 + log_grads: false + +training: + f16_precision: true + checkpoint_path: null + steps: 1e6 + val_logs_every: 2e4 + +validation: + rollout_num: 5 + visualize: true + metrics: + - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator + log_video: True + _partial_: true + - _target_: rl_sandbox.metrics.PostSlottedDinoDreamerMetricsEvaluator + _partial_: true + +debug: + profiler: false + +hydra: + #mode: MULTIRUN + mode: RUN + launcher: + n_jobs: 3 + sweeper: + params: + seed: 17,42,45 + env: dm_finger_spin,dm_finger_turn_hard + diff --git a/rl_sandbox/config/env/dm_finger_turn_easy.yaml b/rl_sandbox/config/env/dm_finger_turn_easy.yaml new file mode 100644 index 0000000..bbc6de7 --- /dev/null +++ b/rl_sandbox/config/env/dm_finger_turn_easy.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: finger +task_name: turn_easy +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_finger_turn_hard.yaml b/rl_sandbox/config/env/dm_finger_turn_hard.yaml new file mode 100644 index 0000000..b040df1 --- /dev/null +++ b/rl_sandbox/config/env/dm_finger_turn_hard.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: finger +task_name: turn_hard +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_hopper_hop.yaml b/rl_sandbox/config/env/dm_hopper_hop.yaml new file mode 100644 index 0000000..ff8998f --- /dev/null +++ b/rl_sandbox/config/env/dm_hopper_hop.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: hopper +task_name: hop +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_quadruped.yaml b/rl_sandbox/config/env/dm_quadruped.yaml index aa5e541..9f73398 100644 --- a/rl_sandbox/config/env/dm_quadruped.yaml +++ b/rl_sandbox/config/env/dm_quadruped.yaml @@ -1,6 +1,6 @@ _target_: rl_sandbox.utils.env.DmEnv domain_name: quadruped -task_name: walk +task_name: run run_on_pixels: true obs_res: [64, 64] camera_id: 2 diff --git a/rl_sandbox/config/env/dm_quadruped_walk.yaml b/rl_sandbox/config/env/dm_quadruped_walk.yaml new file mode 100644 index 
0000000..aa5e541 --- /dev/null +++ b/rl_sandbox/config/env/dm_quadruped_walk.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: quadruped +task_name: walk +run_on_pixels: true +obs_res: [64, 64] +camera_id: 2 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_reacher_hard.yaml b/rl_sandbox/config/env/dm_reacher_hard.yaml new file mode 100644 index 0000000..6ecbd96 --- /dev/null +++ b/rl_sandbox/config/env/dm_reacher_hard.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: reacher +task_name: hard +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/env/dm_walker.yaml b/rl_sandbox/config/env/dm_walker.yaml index 68d8cf5..c97057c 100644 --- a/rl_sandbox/config/env/dm_walker.yaml +++ b/rl_sandbox/config/env/dm_walker.yaml @@ -1,6 +1,6 @@ _target_: rl_sandbox.utils.env.DmEnv domain_name: walker -task_name: walk +task_name: run run_on_pixels: true obs_res: [64, 64] camera_id: 0 diff --git a/rl_sandbox/config/env/dm_walker_stand.yaml b/rl_sandbox/config/env/dm_walker_stand.yaml new file mode 100644 index 0000000..ff2f83a --- /dev/null +++ b/rl_sandbox/config/env/dm_walker_stand.yaml @@ -0,0 +1,8 @@ +_target_: rl_sandbox.utils.env.DmEnv +domain_name: walker +task_name: stand +run_on_pixels: true +obs_res: [64, 64] +camera_id: 0 +repeat_action_num: 2 +transforms: [] diff --git a/rl_sandbox/config/training/dm.yaml b/rl_sandbox/config/training/dm.yaml index a4328ba..67fd27f 100644 --- a/rl_sandbox/config/training/dm.yaml +++ b/rl_sandbox/config/training/dm.yaml @@ -1,9 +1,9 @@ steps: 1e6 prefill: 1000 -batch_size: 16 +batch_size: 50 pretrain: 100 prioritize_ends: false -train_every: 5 +train_every: 4 save_checkpoint_every: 2e6 val_logs_every: 2e4 f16_precision: false diff --git a/rl_sandbox/metrics.py b/rl_sandbox/metrics.py index 8d85d28..ce7f5ac 100644 --- a/rl_sandbox/metrics.py +++ b/rl_sandbox/metrics.py @@ -440,3 +440,79 @@ def _generate_video(self, obs: list[Observation], actions: list[Action], update_ return torch.cat(video), rews, torch.cat(slots_video) + +class PostSlottedDinoDreamerMetricsEvaluator(SlottedDreamerMetricsEvaluator): + def on_step(self, logger): + self.stored_steps += 1 + + if self.agent.is_discrete: + self._action_probs += self._action_probs + self._latent_probs += self.agent._state[0].stoch_dist.base_dist.probs.squeeze() + + def _generate_video(self, obs: list[Observation], actions: list[Action], update_num: int): + # obs = torch.from_numpy(obs.copy()).to(self.agent.device) + # obs = self.agent.preprocess_obs(obs) + # actions = self.agent.from_np(actions) + if self.agent.is_discrete: + actions = F.one_hot(actions.to(torch.int64), num_classes=self.agent.actions_num).squeeze() + video = [] + slots_video = [] + rews = [] + + vit_size = self.agent.world_model.vit_size + + state = None + prev_slots = None + for idx, (o, a) in enumerate(list(zip(obs, actions))): + if idx > update_num: + break + state, prev_slots = self.agent.world_model.get_latent(o, a.unsqueeze(0).unsqueeze(0), (state, prev_slots)) + # video_r = self.agent.world_model.image_predictor(state.combined_slots).mode + + wm_state = self.agent.world_model.state_reshuffle(state.combined) + wm_state = wm_state.reshape(*wm_state.shape[:-1], self.agent.world_model.state_feature_num, self.agent.world_model.n_dim) + wm_state_pos_embedded = self.agent.world_model.positional_augmenter_inp(wm_state.unsqueeze(-3)).squeeze(-3) + wm_state_slots = 
self.agent.world_model.slot_attention(wm_state_pos_embedded.flatten(0, 1), None) + + # decoded_imgs, masks = self.agent.world_model.image_predictor(wm_state_slots.flatten(0, 1)).reshape(1, -1, 4, 64, 64).split([3, 1], dim=2) + # TODO: try the scaling of softmax as in attention + # img_mask = self.agent.world_model.slot_mask(masks) + # decoded_imgs = decoded_imgs * img_mask + # video_r = torch.sum(decoded_imgs, dim=1) + + decoded_dino_feats, vit_masks = self.agent.world_model.dino_predictor(wm_state_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, vit_size, vit_size).split([self.agent.world_model.vit_feat_dim, 1], dim=2) + vit_mask = self.agent.world_model.slot_mask(vit_masks) + decoded_dino_feats = decoded_dino_feats * vit_mask + decoded_dino = (decoded_dino_feats).sum(dim=1) + upscale = tv.transforms.Resize(64, antialias=True) + + upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) + per_slot_vit = (upscaled_mask.unsqueeze(1) * o.to(self.agent.device).unsqueeze(0)).unsqueeze(0) + + rews.append(self.agent.world_model.reward_predictor(state.combined).mode.item()) + video.append(self.agent.unprocess_obs(o).unsqueeze(0)) + slots_video.append(self.agent.unprocess_obs(per_slot_vit)) + + rews = torch.Tensor(rews).to(obs.device) + + if update_num < len(obs): + states, _, rews_2, _ = self.agent.imagine_trajectory(state, actions[update_num+1:].unsqueeze(1), horizon=self.agent.imagination_horizon - 1 - update_num) + rews = torch.cat([rews, rews_2[1:].squeeze()]) + + wm_state = self.agent.world_model.state_reshuffle(states.combined[1:]) + wm_state = wm_state.reshape(*wm_state.shape[:-1], self.agent.world_model.state_feature_num, self.agent.world_model.n_dim) + wm_state_pos_embedded = self.agent.world_model.positional_augmenter_inp(wm_state.unsqueeze(-3)).squeeze(-3) + wm_state_slots = self.agent.world_model.slot_attention(wm_state_pos_embedded.flatten(0, 1), None) + + decoded_dino_feats, vit_masks = self.agent.world_model.dino_predictor(wm_state_slots.flatten(0, 1)).reshape(-1, self.agent.world_model.slots_num, self.agent.world_model.vit_feat_dim+1, vit_size, vit_size).split([self.agent.world_model.vit_feat_dim, 1], dim=2) + vit_mask = F.softmax(vit_masks, dim=1) + decoded_dino = (decoded_dino_feats * vit_mask).sum(dim=1) + + upscale = tv.transforms.Resize(64, antialias=True) + upscaled_mask = upscale(vit_mask.permute(0, 1, 4, 2, 3).squeeze()) + per_slot_vit = (upscaled_mask.unsqueeze(2) * obs[update_num+1:].to(self.agent.device).unsqueeze(1)) + + video.append(self.agent.unprocess_obs(obs[update_num+1:])) + slots_video.append(self.agent.unprocess_obs(per_slot_vit)) + + return torch.cat(video), rews, torch.cat(slots_video) diff --git a/rl_sandbox/train.py b/rl_sandbox/train.py index 5c61b59..caca930 100644 --- a/rl_sandbox/train.py +++ b/rl_sandbox/train.py @@ -1,7 +1,7 @@ import random import os os.environ['MUJOCO_GL'] = 'egl' -# os.environ["WANDB_MODE"]="offline" +os.environ["WANDB_MODE"]="offline" import crafter import hydra @@ -76,7 +76,8 @@ def main(cfg: DictConfig): f16_precision=cfg.training.f16_precision, logger=logger) - buff = ReplayBuffer(prioritize_ends=cfg.training.prioritize_ends, + buff = ReplayBuffer(max_len=500_000, + prioritize_ends=cfg.training.prioritize_ends, min_ep_len=cfg.agent.get('batch_cluster_size', 1) * (cfg.training.prioritize_ends + 1), preprocess_func=agent.preprocess,