
Commit 4e437f2

Try newer dino-only attention slotted
1 parent aaa5dc4 commit 4e437f2

5 files changed (+22 −19 lines)

rl_sandbox/agents/dreamer/vision.py (+3 −3)

@@ -6,8 +6,8 @@ class Encoder(nn.Module):
 
     def __init__(self, norm_layer: nn.GroupNorm | nn.Identity,
                  channel_step=96,
-                 kernel_sizes=[4, 4, 4],
-                 double_conv=False,
+                 kernel_sizes=[4, 4, 4, 4],
+                 post_conv_num: int = 0,
                  flatten_output=True,
                  in_channels=3,
                  ):
@@ -21,7 +21,7 @@ def __init__(self, norm_layer: nn.GroupNorm | nn.Identity,
             layers.append(nn.ELU(inplace=True))
             in_channels = out_channels
 
-        for i, k in enumerate(kernel_sizes):
+        for k in range(post_conv_num):
             layers.append(
                 nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same'))
             layers.append(norm_layer(1, out_channels))
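For orientation, below is a minimal, hypothetical sketch of the layer pattern the new constructor builds: a downsampling stack driven by kernel_sizes, followed by post_conv_num resolution-preserving 3×3 convolutions. The stride-2 unpadded downsampling convs and the doubling channel schedule are assumptions (consistent with DreamerV2-style encoders and with the positional-embedding change in the next file), not a copy of the repo code.

import torch
import torch.nn as nn

def build_encoder(in_channels=3, channel_step=96, kernel_sizes=(4, 4), post_conv_num=3):
    layers = []
    out_channels = in_channels
    # Downsampling stack: one conv per entry in kernel_sizes
    # (stride=2, no padding, and the 2**i channel schedule are assumptions).
    for i, k in enumerate(kernel_sizes):
        out_channels = channel_step * 2**i
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=2))
        layers.append(nn.GroupNorm(1, out_channels))
        layers.append(nn.ELU(inplace=True))
        in_channels = out_channels
    # New in this commit: post_conv_num stride-1 convs with 'same' padding,
    # adding depth without shrinking the feature map (replaces double_conv).
    for _ in range(post_conv_num):
        layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same'))
        layers.append(nn.GroupNorm(1, out_channels))
        layers.append(nn.ELU(inplace=True))
    return nn.Sequential(*layers)

enc = build_encoder()
print(enc(torch.zeros(1, 3, 64, 64)).shape)  # torch.Size([1, 192, 14, 14])

Under these assumptions, two downsampling convs leave a 14×14 feature map where three would leave 6×6, which matches the PositionalEmbedding change below.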

rl_sandbox/agents/dreamer/world_model_slots_attention.py (+6 −6)

@@ -97,17 +97,17 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim,
             )
         else:
             self.encoder = Encoder(norm_layer=nn.GroupNorm if layer_norm else nn.Identity,
-                                   kernel_sizes=[4, 4, 4],
-                                   channel_step=48 * (self.n_dim // 192),
-                                   double_conv=True,
+                                   kernel_sizes=[4, 4],
+                                   channel_step=48 * (self.n_dim // 192) * 2,
+                                   post_conv_num=3,
                                    flatten_output=False)
 
         self.slot_attention = SlotAttention(slots_num, self.n_dim, slots_iter_num, use_prev_slots)
         self.register_buffer('pos_enc', torch.from_numpy(get_position_encoding(self.slots_num, self.state_size // slots_num)).to(dtype=torch.float32))
         if self.encode_vit:
             self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (4, 4))
         else:
-            self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (6, 6))
+            self.positional_augmenter_inp = PositionalEmbedding(self.n_dim, (14, 14))
 
         self.slot_mlp = nn.Sequential(nn.Linear(self.n_dim, self.n_dim),
                                       nn.ReLU(inplace=True),
@@ -116,8 +116,8 @@ def __init__(self, batch_cluster_size, latent_dim, latent_classes, rssm_dim,
         if decode_vit:
             self.dino_predictor = Decoder(rssm_dim + latent_dim * latent_classes,
                                           norm_layer=nn.GroupNorm if layer_norm else nn.Identity,
-                                          conv_kernel_sizes=[],
-                                          channel_step=self.vit_feat_dim,
+                                          conv_kernel_sizes=[3, 3],
+                                          channel_step=2*self.vit_feat_dim,
                                           kernel_sizes=self.decoder_kernels,
                                           output_channels=self.vit_feat_dim+1,
                                           return_dist=False)
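The (6, 6) to (14, 14) positional-grid change follows directly from the encoder dropping its third downsampling conv. A quick arithmetic check, assuming stride-2 unpadded 4×4 convolutions on a 64×64 observation (an assumption, consistent with the encoder defaults above):

def conv_out(size, kernel, stride=2, padding=0):
    # Output size of a single convolution (floor division, as in PyTorch).
    return (size + 2 * padding - kernel) // stride + 1

size = 64
for k in [4, 4, 4]:  # old encoder: three downsampling convs
    size = conv_out(size, k)
print(size)  # 6

size = 64
for k in [4, 4]:     # new encoder: two downsampling convs
    size = conv_out(size, k)
print(size)  # 14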

rl_sandbox/config/agent/dreamer_v2_crafter_slotted.yaml (+2 −2)

@@ -7,8 +7,8 @@ world_model:
   rssm_dim: 512
   slots_num: 6
   slots_iter_num: 2
-  kl_loss_scale: 1e2
+  kl_loss_scale: 1.0
   decode_vit: true
-  use_prev_slots: true
+  use_prev_slots: false
   vit_l2_ratio: 0.1
   encode_vit: false

rl_sandbox/config/agent/dreamer_v2_slotted_attention.yaml (+2 −2)

@@ -4,7 +4,7 @@ defaults:
 
 world_model:
   _target_: rl_sandbox.agents.dreamer.world_model_slots_attention.WorldModel
-  rssm_dim: 1024
+  rssm_dim: 768
   slots_num: 4
   slots_iter_num: 3
   kl_loss_scale: 1.0
@@ -13,7 +13,7 @@ world_model:
   mask_combination: soft
   use_prev_slots: false
   per_slot_rec_loss: false
-  vit_l2_ratio: 0.1
+  vit_l2_ratio: 0.5
 
   full_qk_from: 4e4
   symmetric_qk: false

rl_sandbox/config/config_attention.yaml (+9 −6)

@@ -7,21 +7,22 @@ defaults:
   - override hydra/launcher: joblib
 
 seed: 42
-device_type: cuda:0
+device_type: cuda:1
 
 agent:
   world_model:
     encode_vit: false
-    decode_vit: true
+    decode_vit: false
     vit_img_size: 224
     vit_l2_ratio: 1.0
     slots_iter_num: 3
-    slots_num: 6
+    slots_num: 4
     kl_loss_scale: 2.0
+    kl_loss_balancing: 0.8
     kl_free_nats: 1.0
 
 logger:
-  message: Attention, only dino, kl=20, removed symmetric, add warmup
+  message: Attention, without dino, kl=2, removed symmetric, add warmup, 4 slots, 768,
   log_grads: false
 
 training:
@@ -36,8 +37,10 @@ validation:
   - _target_: rl_sandbox.metrics.EpisodeMetricsEvaluator
     log_video: True
     _partial_: true
-  - _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator
-  #- _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator
+  #- _target_: rl_sandbox.metrics.SlottedDinoDreamerMetricsEvaluator
+  - _target_: rl_sandbox.metrics.SlottedDreamerMetricsEvaluator
+    _partial_: true
+  - _target_: rl_sandbox.crafter_metrics.CrafterMetricsEvaluator
     _partial_: true
 
 debug:
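Since these are Hydra configs (note the defaults list and the joblib launcher override), the commit's net settings can be inspected by composing config_attention. This is a hypothetical usage sketch: the compose API is standard Hydra, but the working directory and config layout are assumptions based on the paths above.

from hydra import compose, initialize

# Assumes the current directory is the repository root, so that
# rl_sandbox/config/config_attention.yaml resolves as below.
with initialize(config_path="rl_sandbox/config", version_base=None):
    cfg = compose(config_name="config_attention")

print(cfg.device_type)                    # cuda:1
print(cfg.agent.world_model.slots_num)    # 4
print(cfg.agent.world_model.decode_vit)   # False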
