diff --git a/libs/uvit.py b/libs/uvit.py index ee43e16..a944539 100644 --- a/libs/uvit.py +++ b/libs/uvit.py @@ -97,6 +97,7 @@ class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): super().__init__() + self.norm_skip = norm_layer(2*dim) if skip else None self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale) @@ -114,7 +115,8 @@ def forward(self, x, skip=None): def _forward(self, x, skip=None): if self.skip_linear is not None: - x = self.skip_linear(torch.cat([x, skip], dim=-1)) + skip_in_ = self.norm_skip(torch.cat([x, skip], dim=-1)) if self.norm_skip is not None else torch.cat([x, skip], dim=-1) + x = self.skip_linear(skip_in_) x = x + self.attn(self.norm1(x)) x = x + self.mlp(self.norm2(x)) return x