From 06ddb8900d40e409e493990ff98d8a39633fc104 Mon Sep 17 00:00:00 2001
From: schilcher
Date: Sat, 1 Nov 2025 14:19:55 +0100
Subject: [PATCH] Fix numerical instabilities during training by normalizing inputs to long skip connections (analogous to MHSA and MLP)

---
 libs/uvit.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/libs/uvit.py b/libs/uvit.py
index ee43e16..a944539 100644
--- a/libs/uvit.py
+++ b/libs/uvit.py
@@ -97,6 +97,7 @@ class Block(nn.Module):
     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                  act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
         super().__init__()
+        self.norm_skip = norm_layer(2*dim) if skip else None
         self.norm1 = norm_layer(dim)
         self.attn = Attention(
             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
@@ -114,7 +115,8 @@ def forward(self, x, skip=None):
 
     def _forward(self, x, skip=None):
         if self.skip_linear is not None:
-            x = self.skip_linear(torch.cat([x, skip], dim=-1))
+            skip_in_ = self.norm_skip(torch.cat([x, skip], dim=-1)) if self.norm_skip is not None else torch.cat([x, skip], dim=-1)
+            x = self.skip_linear(skip_in_)
         x = x + self.attn(self.norm1(x))
         x = x + self.mlp(self.norm2(x))
         return x