diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 6f0da73e7b..03b722a731 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -79,16 +79,16 @@ def __init__(
         )
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention
-
-        self.norm_cross_attn = nn.LayerNorm(hidden_size)
-        self.cross_attn = CrossAttentionBlock(
-            hidden_size=hidden_size,
-            num_heads=num_heads,
-            dropout_rate=dropout_rate,
-            qkv_bias=qkv_bias,
-            causal=False,
-            use_flash_attention=use_flash_attention,
-        )
+        if with_cross_attention:
+            self.norm_cross_attn = nn.LayerNorm(hidden_size)
+            self.cross_attn = CrossAttentionBlock(
+                hidden_size=hidden_size,
+                num_heads=num_heads,
+                dropout_rate=dropout_rate,
+                qkv_bias=qkv_bias,
+                causal=False,
+                use_flash_attention=use_flash_attention,
+            )
 
     def forward(
         self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
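
For context, a minimal, self-contained sketch of the pattern this patch applies: the cross-attention submodules are only constructed when the enabling flag is set, so a block built with `with_cross_attention=False` no longer registers unused (and uninitialized-by-design) cross-attention parameters. The `Block` class here is hypothetical and uses `nn.MultiheadAttention` as a stand-in for MONAI's `CrossAttentionBlock`; it is not the MONAI implementation.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Sketch of the conditional-submodule pattern from the patch above."""

    def __init__(self, hidden_size: int, num_heads: int = 4, with_cross_attention: bool = False):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.with_cross_attention = with_cross_attention
        if with_cross_attention:
            # Only built when needed, mirroring the patched __init__.
            self.norm_cross_attn = nn.LayerNorm(hidden_size)
            self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        y = self.norm1(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]
        if self.with_cross_attention:
            # The cross-attention path is guarded by the same flag, so the
            # conditionally created submodules are never referenced otherwise.
            y = self.norm_cross_attn(x)
            ctx = context if context is not None else x
            x = x + self.cross_attn(y, ctx, ctx, need_weights=False)[0]
        return x


# Before the patch, both variants would hold cross-attention weights; after
# it, only the with_cross_attention=True variant does:
n_params = lambda m: sum(p.numel() for p in m.parameters())
print(n_params(Block(64, with_cross_attention=False)) < n_params(Block(64, with_cross_attention=True)))  # True
```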