@@ -29,6 +29,8 @@ def __init__(
             self,
             dim: int,
             num_heads: int = 8,
+            attn_head_dim: Optional[int] = None,
+            dim_out: Optional[int] = None,
             qkv_bias: bool = False,
             qk_norm: bool = False,
             scale_norm: bool = False,
@@ -37,36 +39,45 @@ def __init__(
             proj_drop: float = 0.,
             norm_layer: Optional[Type[nn.Module]] = None,
             device=None,
-            dtype=None
+            dtype=None,
     ) -> None:
         """Initialize the Attention module.
 
         Args:
-            dim: Input dimension of the token embeddings
-            num_heads: Number of attention heads
-            qkv_bias: Whether to use bias in the query, key, value projections
-            qk_norm: Whether to apply normalization to query and key vectors
-            proj_bias: Whether to use bias in the output projection
-            attn_drop: Dropout rate applied to the attention weights
-            proj_drop: Dropout rate applied after the output projection
-            norm_layer: Normalization layer constructor for QK normalization if enabled
+            dim: Input dimension of the token embeddings.
+            num_heads: Number of attention heads.
+            attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads.
+            dim_out: Output dimension. If None, same as dim.
+            qkv_bias: Whether to use bias in the query, key, value projections.
+            qk_norm: Whether to apply normalization to query and key vectors.
+            scale_norm: Whether to apply normalization to attention output before projection.
+            proj_bias: Whether to use bias in the output projection.
+            attn_drop: Dropout rate applied to the attention weights.
+            proj_drop: Dropout rate applied after the output projection.
+            norm_layer: Normalization layer constructor for QK normalization if enabled.
         """
         super().__init__()
         dd = {'device': device, 'dtype': dtype}
-        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        dim_out = dim_out or dim
+        head_dim = attn_head_dim
+        if head_dim is None:
+            assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+            head_dim = dim // num_heads
         if qk_norm or scale_norm:
             assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
+
         self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.scale = self.head_dim ** -0.5
+        self.head_dim = head_dim
+        self.attn_dim = num_heads * head_dim
+        self.scale = head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
-        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd)
+        self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity()
-        self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
+        self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(
@@ -93,7 +104,7 @@ def forward(
             attn = self.attn_drop(attn)
             x = attn @ v
 
-        x = x.transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
         x = self.norm(x)
         x = self.proj(x)
         x = self.proj_drop(x)
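A minimal usage sketch of the new `Attention` arguments (not part of the diff; it assumes the class is importable as `timm.layers.Attention`, which may vary by timm version). With the defaults nothing changes; passing `attn_head_dim` and `dim_out` decouples the per-head width and the projection output width from `dim`:

```python
import torch
from timm.layers import Attention  # assumed import path

x = torch.randn(2, 197, 768)  # (batch, tokens, dim)

# Default behaviour is unchanged: head_dim = dim // num_heads, output dim = dim.
attn = Attention(dim=768, num_heads=12)
assert attn(x).shape == (2, 197, 768)

# With attn_head_dim=32 and dim_out=512: attn_dim = num_heads * attn_head_dim = 384,
# so qkv becomes Linear(768, 3 * 384) and proj becomes Linear(384, 512).
attn = Attention(dim=768, num_heads=12, attn_head_dim=32, dim_out=512)
assert attn(x).shape == (2, 197, 512)
```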
@@ -114,6 +125,7 @@ def __init__(
             self,
             dim: int,
             num_heads: int = 8,
+            dim_out: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_fused: bool = True,
             num_prefix_tokens: int = 1,
@@ -133,45 +145,52 @@ def __init__(
         Args:
             dim: Input dimension of the token embeddings
             num_heads: Number of attention heads
+            dim_out: Output dimension. If None, same as dim.
             qkv_bias: Whether to add a bias term to the query, key, and value projections
+            qkv_fused: Whether to use fused QKV projection (single linear) or separate projections
             num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
                 should not have position embeddings applied
             attn_drop: Dropout rate for attention weights
             proj_drop: Dropout rate for the output projection
-            attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
+            attn_head_dim: Dimension of each attention head. If None, computed as dim // num_heads.
             norm_layer: Normalization layer constructor to use for QK and scale normalization
             qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
             scale_norm: Enable normalization (scaling) of attention output with norm_layer
+            proj_bias: Whether to use bias in the output projection
             rotate_half: Use 'half' ROPE layout instead of default 'interleaved'
         """
         super().__init__()
         dd = {'device': device, 'dtype': dtype}
+        dim_out = dim_out or dim
+        head_dim = attn_head_dim
+        if head_dim is None:
+            assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+            head_dim = dim // num_heads
         if scale_norm or qk_norm:
             assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
+
         self.num_heads = num_heads
-        head_dim = dim // num_heads
-        if attn_head_dim is not None:
-            head_dim = attn_head_dim
-        attn_dim = head_dim * self.num_heads
+        self.head_dim = head_dim
+        self.attn_dim = head_dim * num_heads
         self.scale = head_dim ** -0.5
         self.num_prefix_tokens = num_prefix_tokens
         self.fused_attn = use_fused_attn()
         self.rotate_half = rotate_half
 
         if qkv_fused:
-            self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd)
+            self.qkv = nn.Linear(dim, self.attn_dim * 3, bias=qkv_bias, **dd)
             self.q_proj = self.k_proj = self.v_proj = None
         else:
             self.qkv = None
-            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
-            self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
-            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
+            self.q_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
+            self.k_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
+            self.v_proj = nn.Linear(dim, self.attn_dim, bias=qkv_bias, **dd)
 
         self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity()
-        self.proj = nn.Linear(attn_dim, dim, bias=proj_bias, **dd)
+        self.norm = norm_layer(self.attn_dim, **dd) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(self.attn_dim, dim_out, bias=proj_bias, **dd)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(
@@ -188,18 +207,18 @@ def forward(
             attn_mask: Optional attention mask to apply during attention computation
 
         Returns:
-            Tensor of shape (batch_size, sequence_length, embedding_dim)
+            Tensor of shape (batch_size, sequence_length, dim_out)
         """
         B, N, C = x.shape
 
         if self.qkv is not None:
             qkv = self.qkv(x)
-            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+            qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
             q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
         else:
-            q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)  # B, num_heads, N, C
-            k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
-            v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
+            q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
 
         q, k = self.q_norm(q), self.k_norm(k)
 
@@ -224,7 +243,7 @@ def forward(
             attn = self.attn_drop(attn)
             x = attn @ v
 
-        x = x.transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
         x = self.norm(x)
         x = self.proj(x)
         x = self.proj_drop(x)
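The same pattern applies to `AttentionRope`, which now exposes `dim_out` alongside the existing `attn_head_dim`. A hedged sketch (assuming the class is importable as `timm.layers.AttentionRope` and that `rope` may be omitted, in which case no rotary embedding is applied):

```python
import torch
from timm.layers import AttentionRope  # assumed import path

x = torch.randn(2, 197, 768)

# head_dim comes from attn_head_dim, output width from dim_out:
# qkv is Linear(768, 3 * 12 * 64) and proj is Linear(768, 384).
attn = AttentionRope(dim=768, num_heads=12, attn_head_dim=64, dim_out=384)
assert attn(x).shape == (2, 197, 384)  # (batch, tokens, dim_out)
```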