diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index c720b379551f..2557cea67857 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -12,19 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F -from ..utils import deprecate, logging +from ..utils import logging from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available -from ..utils.torch_utils import maybe_allow_in_graph -from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU -from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0 -from .embeddings import SinusoidalPositionalEmbedding -from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX +from .attention_processor import Attention, AttentionProcessor # noqa if is_xformers_available(): @@ -509,1224 +505,132 @@ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> tor return encoder_hidden_states -def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): - # "feed_forward_chunk_size" can be used to save memory - if hidden_states.shape[chunk_dim] % chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." - ) - - num_chunks = hidden_states.shape[chunk_dim] // chunk_size - ff_output = torch.cat( - [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], - dim=chunk_dim, +def _chunked_feed_forward(*args, **kwargs): + """Backward compatibility stub. Use transformers.modeling_common._chunked_feed_forward instead.""" + logger.warning( + "Importing `_chunked_feed_forward` from `diffusers.models.attention` is deprecated and will be removed in a future version. " + "Please use `from diffusers.models.transformers.modeling_common import _chunked_feed_forward` instead." ) - return ff_output - - -@maybe_allow_in_graph -class GatedSelfAttentionDense(nn.Module): - r""" - A gated self-attention dense layer that combines visual features and object features. - - Parameters: - query_dim (`int`): The number of channels in the query. - context_dim (`int`): The number of channels in the context. - n_heads (`int`): The number of heads to use for attention. - d_head (`int`): The number of channels in each head. 
- """ - - def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): - super().__init__() - - # we need a linear projection since we need cat visual feature and obj feature - self.linear = nn.Linear(context_dim, query_dim) - - self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) - self.ff = FeedForward(query_dim, activation_fn="geglu") - - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) - - self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) - self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + from .transformers.modeling_common import _chunked_feed_forward as _actual_chunked_feed_forward - self.enabled = True + return _actual_chunked_feed_forward(*args, **kwargs) - def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: - if not self.enabled: - return x - n_visual = x.shape[1] - objs = self.linear(objs) - - x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] - x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) - - return x - - -@maybe_allow_in_graph -class JointTransformerBlock(nn.Module): +class GatedSelfAttentionDense: r""" - A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. - - Reference: https://huggingface.co/papers/2403.03206 - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the - processing of `context` conditions. + Backward compatibility stub. Use transformers.modeling_common.GatedSelfAttentionDense instead. """ - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - context_pre_only: bool = False, - qk_norm: Optional[str] = None, - use_dual_attention: bool = False, - ): - super().__init__() - - self.use_dual_attention = use_dual_attention - self.context_pre_only = context_pre_only - context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" - - if use_dual_attention: - self.norm1 = SD35AdaLayerNormZeroX(dim) - else: - self.norm1 = AdaLayerNormZero(dim) - - if context_norm_type == "ada_norm_continous": - self.norm1_context = AdaLayerNormContinuous( - dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" - ) - elif context_norm_type == "ada_norm_zero": - self.norm1_context = AdaLayerNormZero(dim) - else: - raise ValueError( - f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" - ) - - if hasattr(F, "scaled_dot_product_attention"): - processor = JointAttnProcessor2_0() - else: - raise ValueError( - "The current PyTorch version does not support the `scaled_dot_product_attention` function." 
- ) - - self.attn = Attention( - query_dim=dim, - cross_attention_dim=None, - added_kv_proj_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - context_pre_only=context_pre_only, - bias=True, - processor=processor, - qk_norm=qk_norm, - eps=1e-6, - ) - - if use_dual_attention: - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=None, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - bias=True, - processor=processor, - qk_norm=qk_norm, - eps=1e-6, - ) - else: - self.attn2 = None - - self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - - if not context_pre_only: - self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - else: - self.norm2_context = None - self.ff_context = None - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor, - temb: torch.FloatTensor, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - ): - joint_attention_kwargs = joint_attention_kwargs or {} - if self.use_dual_attention: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( - hidden_states, emb=temb - ) - else: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) - - if self.context_pre_only: - norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) - else: - norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( - encoder_hidden_states, emb=temb - ) - - # Attention. - attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - **joint_attention_kwargs, + def __new__(cls, *args, **kwargs): + logger.warning( + "Importing `GatedSelfAttentionDense` from `diffusers.models.attention` is deprecated and will be removed in a future version. " + "Please use `from diffusers.models.transformers.modeling_common import GatedSelfAttentionDense` instead." ) + from .transformers.modeling_common import GatedSelfAttentionDense - # Process attention outputs for the `hidden_states`. 
- attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = hidden_states + attn_output - - if self.use_dual_attention: - attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs) - attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 - hidden_states = hidden_states + attn_output2 - - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - else: - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = hidden_states + ff_output + return GatedSelfAttentionDense(*args, **kwargs) - # Process attention outputs for the `encoder_hidden_states`. - if self.context_pre_only: - encoder_hidden_states = None - else: - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - context_ff_output = _chunked_feed_forward( - self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size - ) - else: - context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output - return encoder_hidden_states, hidden_states - - -@maybe_allow_in_graph -class BasicTransformerBlock(nn.Module): +class JointTransformerBlock: r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - upcast_attention (`bool`, *optional*): - Whether to upcast the attention computation to float32. This is useful for mixed precision training. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_type (`str`, *optional*, defaults to `"layer_norm"`): - The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. - final_dropout (`bool` *optional*, defaults to False): - Whether to apply a final dropout after the last feed-forward layer. 
- attention_type (`str`, *optional*, defaults to `"default"`): - The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. - positional_embeddings (`str`, *optional*, defaults to `None`): - The type of positional embeddings to apply to. - num_positional_embeddings (`int`, *optional*, defaults to `None`): - The maximum number of positional embeddings to apply. + Backward compatibility stub. Use transformers.modeling_common.JointTransformerBlock instead. """ - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' - norm_eps: float = 1e-5, - final_dropout: bool = False, - attention_type: str = "default", - positional_embeddings: Optional[str] = None, - num_positional_embeddings: Optional[int] = None, - ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, - ada_norm_bias: Optional[int] = None, - ff_inner_dim: Optional[int] = None, - ff_bias: bool = True, - attention_out_bias: bool = True, - ): - super().__init__() - self.dim = dim - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - self.dropout = dropout - self.cross_attention_dim = cross_attention_dim - self.activation_fn = activation_fn - self.attention_bias = attention_bias - self.double_self_attention = double_self_attention - self.norm_elementwise_affine = norm_elementwise_affine - self.positional_embeddings = positional_embeddings - self.num_positional_embeddings = num_positional_embeddings - self.only_cross_attention = only_cross_attention - - # We keep these boolean flags for backward-compatibility. - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - self.use_ada_layer_norm_single = norm_type == "ada_norm_single" - self.use_layer_norm = norm_type == "layer_norm" - self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - self.norm_type = norm_type - self.num_embeds_ada_norm = num_embeds_ada_norm - - if positional_embeddings and (num_positional_embeddings is None): - raise ValueError( - "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." - ) - - if positional_embeddings == "sinusoidal": - self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) - else: - self.pos_embed = None - - # Define 3 blocks. Each block has its own normalization layer. - # 1. 
Self-Attn - if norm_type == "ada_norm": - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif norm_type == "ada_norm_zero": - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - elif norm_type == "ada_norm_continuous": - self.norm1 = AdaLayerNormContinuous( - dim, - ada_norm_continous_conditioning_embedding_dim, - norm_elementwise_affine, - norm_eps, - ada_norm_bias, - "rms_norm", - ) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, + def __new__(cls, *args, **kwargs): + logger.warning( + "Importing `JointTransformerBlock` from `diffusers.models.attention` is deprecated and will be removed in a future version. " + "Please use `from diffusers.models.transformers.modeling_common import JointTransformerBlock` instead." ) + from .transformers.modeling_common import JointTransformerBlock - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - if norm_type == "ada_norm": - self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif norm_type == "ada_norm_continuous": - self.norm2 = AdaLayerNormContinuous( - dim, - ada_norm_continous_conditioning_embedding_dim, - norm_elementwise_affine, - norm_eps, - ada_norm_bias, - "rms_norm", - ) - else: - self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, - ) # is self-attn if encoder_hidden_states is none - else: - if norm_type == "ada_norm_single": # For Latte - self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - else: - self.norm2 = None - self.attn2 = None - - # 3. Feed-forward - if norm_type == "ada_norm_continuous": - self.norm3 = AdaLayerNormContinuous( - dim, - ada_norm_continous_conditioning_embedding_dim, - norm_elementwise_affine, - norm_eps, - ada_norm_bias, - "layer_norm", - ) - - elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: - self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - elif norm_type == "layer_norm_i2vgen": - self.norm3 = None - - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - inner_dim=ff_inner_dim, - bias=ff_bias, - ) - - # 4. Fuser - if attention_type == "gated" or attention_type == "gated-text-image": - self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) - - # 5. Scale-shift for PixArt-Alpha. 
- if norm_type == "ada_norm_single": - self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - # Notice that normalization is always applied before the real computation in the following blocks. - # 0. Self-Attention - batch_size = hidden_states.shape[0] - - if self.norm_type == "ada_norm": - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.norm_type == "ada_norm_zero": - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm1(hidden_states) - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) - elif self.norm_type == "ada_norm_single": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) - ).chunk(6, dim=1) - norm_hidden_states = self.norm1(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - else: - raise ValueError("Incorrect norm used") - - if self.pos_embed is not None: - norm_hidden_states = self.pos_embed(norm_hidden_states) - - # 1. Prepare GLIGEN inputs - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - gligen_kwargs = cross_attention_kwargs.pop("gligen", None) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - if self.norm_type == "ada_norm_zero": - attn_output = gate_msa.unsqueeze(1) * attn_output - elif self.norm_type == "ada_norm_single": - attn_output = gate_msa * attn_output - - hidden_states = attn_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - # 1.2 GLIGEN Control - if gligen_kwargs is not None: - hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) - - # 3. 
Cross-Attention - if self.attn2 is not None: - if self.norm_type == "ada_norm": - norm_hidden_states = self.norm2(hidden_states, timestep) - elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm2(hidden_states) - elif self.norm_type == "ada_norm_single": - # For PixArt norm2 isn't applied here: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 - norm_hidden_states = hidden_states - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) - else: - raise ValueError("Incorrect norm") - - if self.pos_embed is not None and self.norm_type != "ada_norm_single": - norm_hidden_states = self.pos_embed(norm_hidden_states) + return JointTransformerBlock(*args, **kwargs) - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 4. Feed-forward - # i2vgen doesn't have this norm 🤷‍♂️ - if self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) - elif not self.norm_type == "ada_norm_single": - norm_hidden_states = self.norm3(hidden_states) - - if self.norm_type == "ada_norm_zero": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - if self.norm_type == "ada_norm_single": - norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - else: - ff_output = self.ff(norm_hidden_states) - if self.norm_type == "ada_norm_zero": - ff_output = gate_mlp.unsqueeze(1) * ff_output - elif self.norm_type == "ada_norm_single": - ff_output = gate_mlp * ff_output - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - return hidden_states - - -class LuminaFeedForward(nn.Module): +class BasicTransformerBlock: r""" - A feed-forward layer. - - Parameters: - hidden_size (`int`): - The dimensionality of the hidden layers in the model. This parameter determines the width of the model's - hidden representations. - intermediate_size (`int`): The intermediate dimension of the feedforward layer. - multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple - of this value. - ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden - dimension. Defaults to None. + Backward compatibility stub. Use transformers.modeling_common.BasicTransformerBlock instead. 
""" - def __init__( - self, - dim: int, - inner_dim: int, - multiple_of: Optional[int] = 256, - ffn_dim_multiplier: Optional[float] = None, - ): - super().__init__() - # custom hidden_size factor multiplier - if ffn_dim_multiplier is not None: - inner_dim = int(ffn_dim_multiplier * inner_dim) - inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) - - self.linear_1 = nn.Linear( - dim, - inner_dim, - bias=False, - ) - self.linear_2 = nn.Linear( - inner_dim, - dim, - bias=False, + def __new__(cls, *args, **kwargs): + logger.warning( + "Importing `BasicTransformerBlock` from `diffusers.models.attention` is deprecated and will be removed in a future version. " + "Please use `from diffusers.models.transformers.modeling_common import BasicTransformerBlock` instead." ) - self.linear_3 = nn.Linear( - dim, - inner_dim, - bias=False, - ) - self.silu = FP32SiLU() + from .transformers.modeling_common import BasicTransformerBlock - def forward(self, x): - return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x)) + return BasicTransformerBlock(*args, **kwargs) -@maybe_allow_in_graph -class TemporalBasicTransformerBlock(nn.Module): +class LuminaFeedForward: r""" - A basic Transformer block for video like data. - - Parameters: - dim (`int`): The number of channels in the input and output. - time_mix_inner_dim (`int`): The number of channels for temporal attention. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + Backward compatibility stub. Use transformers.modeling_common.LuminaFeedForward instead. """ - def __init__( - self, - dim: int, - time_mix_inner_dim: int, - num_attention_heads: int, - attention_head_dim: int, - cross_attention_dim: Optional[int] = None, - ): - super().__init__() - self.is_res = dim == time_mix_inner_dim - - self.norm_in = nn.LayerNorm(dim) - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - self.ff_in = FeedForward( - dim, - dim_out=time_mix_inner_dim, - activation_fn="geglu", - ) - - self.norm1 = nn.LayerNorm(time_mix_inner_dim) - self.attn1 = Attention( - query_dim=time_mix_inner_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - cross_attention_dim=None, + def __new__(cls, *args, **kwargs): + logger.warning( + "Importing `LuminaFeedForward` from `diffusers.models.attention` is deprecated and will be removed in a future version. " + "Please use `from diffusers.models.transformers.modeling_common import LuminaFeedForward` instead." ) + from .transformers.modeling_common import LuminaFeedForward - # 2. Cross-Attn - if cross_attention_dim is not None: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = nn.LayerNorm(time_mix_inner_dim) - self.attn2 = Attention( - query_dim=time_mix_inner_dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None - - # 3. 
Feed-forward - self.norm3 = nn.LayerNorm(time_mix_inner_dim) - self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = None - - def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): - # Sets chunk feed-forward - self._chunk_size = chunk_size - # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off - self._chunk_dim = 1 - - def forward( - self, - hidden_states: torch.Tensor, - num_frames: int, - encoder_hidden_states: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # Notice that normalization is always applied before the real computation in the following blocks. - # 0. Self-Attention - batch_size = hidden_states.shape[0] + return LuminaFeedForward(*args, **kwargs) - batch_frames, seq_length, channels = hidden_states.shape - batch_size = batch_frames // num_frames - hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) - hidden_states = hidden_states.permute(0, 2, 1, 3) - hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) - - residual = hidden_states - hidden_states = self.norm_in(hidden_states) - - if self._chunk_size is not None: - hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) - else: - hidden_states = self.ff_in(hidden_states) - - if self.is_res: - hidden_states = hidden_states + residual - - norm_hidden_states = self.norm1(hidden_states) - attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) - hidden_states = attn_output + hidden_states - - # 3. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = self.norm2(hidden_states) - attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) - hidden_states = attn_output + hidden_states - - # 4. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self._chunk_size is not None: - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - else: - ff_output = self.ff(norm_hidden_states) - - if self.is_res: - hidden_states = ff_output + hidden_states - else: - hidden_states = ff_output - - hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) - hidden_states = hidden_states.permute(0, 2, 1, 3) - hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) - - return hidden_states - - -class SkipFFTransformerBlock(nn.Module): - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - kv_input_dim: int, - kv_input_dim_proj_use_bias: bool, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - attention_out_bias: bool = True, - ): - super().__init__() - if kv_input_dim != dim: - self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) - else: - self.kv_mapper = None - - self.norm1 = RMSNorm(dim, 1e-06) - - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim, - out_bias=attention_out_bias, - ) - - self.norm2 = RMSNorm(dim, 1e-06) +class TemporalBasicTransformerBlock: + r""" + Backward compatibility stub. Use transformers.modeling_common.TemporalBasicTransformerBlock instead. 
+ """ - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - out_bias=attention_out_bias, + def __new__(cls, *args, **kwargs): + logger.warning( + "Importing `TemporalBasicTransformerBlock` from `diffusers.models.attention` is deprecated and will be removed in a future version. " + "Please use `from diffusers.models.transformers.modeling_common import TemporalBasicTransformerBlock` instead." ) + from .transformers.modeling_common import TemporalBasicTransformerBlock - def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - - if self.kv_mapper is not None: - encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) + return TemporalBasicTransformerBlock(*args, **kwargs) - norm_hidden_states = self.norm1(hidden_states) - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - **cross_attention_kwargs, - ) - - hidden_states = attn_output + hidden_states - - norm_hidden_states = self.norm2(hidden_states) +class SkipFFTransformerBlock: + r""" + Backward compatibility stub. Use transformers.modeling_common.SkipFFTransformerBlock instead. + """ - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - **cross_attention_kwargs, + def __new__(cls, *args, **kwargs): + logger.warning( + "Importing `SkipFFTransformerBlock` from `diffusers.models.attention` is deprecated and will be removed in a future version. " + "Please use `from diffusers.models.transformers.modeling_common import SkipFFTransformerBlock` instead." ) + from .transformers.modeling_common import SkipFFTransformerBlock - hidden_states = attn_output + hidden_states - - return hidden_states + return SkipFFTransformerBlock(*args, **kwargs) -@maybe_allow_in_graph -class FreeNoiseTransformerBlock(nn.Module): +class FreeNoiseTransformerBlock: r""" - A FreeNoise Transformer block. - - Parameters: - dim (`int`): - The number of channels in the input and output. - num_attention_heads (`int`): - The number of heads to use for multi-head attention. - attention_head_dim (`int`): - The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): - The dropout probability to use. - cross_attention_dim (`int`, *optional*): - The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): - Activation function to be used in feed-forward. - num_embeds_ada_norm (`int`, *optional*): - The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (`bool`, defaults to `False`): - Configure if the attentions should contain a bias parameter. - only_cross_attention (`bool`, defaults to `False`): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, defaults to `False`): - Whether to use two self-attention layers. In this case no cross attention layers are used. - upcast_attention (`bool`, defaults to `False`): - Whether to upcast the attention computation to float32. This is useful for mixed precision training. - norm_elementwise_affine (`bool`, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. 
- norm_type (`str`, defaults to `"layer_norm"`): - The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. - final_dropout (`bool` defaults to `False`): - Whether to apply a final dropout after the last feed-forward layer. - attention_type (`str`, defaults to `"default"`): - The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. - positional_embeddings (`str`, *optional*): - The type of positional embeddings to apply to. - num_positional_embeddings (`int`, *optional*, defaults to `None`): - The maximum number of positional embeddings to apply. - ff_inner_dim (`int`, *optional*): - Hidden dimension of feed-forward MLP. - ff_bias (`bool`, defaults to `True`): - Whether or not to use bias in feed-forward MLP. - attention_out_bias (`bool`, defaults to `True`): - Whether or not to use bias in attention output project layer. - context_length (`int`, defaults to `16`): - The maximum number of frames that the FreeNoise block processes at once. - context_stride (`int`, defaults to `4`): - The number of frames to be skipped before starting to process a new batch of `context_length` frames. - weighting_scheme (`str`, defaults to `"pyramid"`): - The weighting scheme to use for weighting averaging of processed latent frames. As described in the - Equation 9. of the [FreeNoise](https://huggingface.co/papers/2310.15169) paper, "pyramid" is the default - setting used. + Backward compatibility stub. Use transformers.modeling_common.FreeNoiseTransformerBlock instead. """ - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout: float = 0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - norm_eps: float = 1e-5, - final_dropout: bool = False, - positional_embeddings: Optional[str] = None, - num_positional_embeddings: Optional[int] = None, - ff_inner_dim: Optional[int] = None, - ff_bias: bool = True, - attention_out_bias: bool = True, - context_length: int = 16, - context_stride: int = 4, - weighting_scheme: str = "pyramid", - ): - super().__init__() - self.dim = dim - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - self.dropout = dropout - self.cross_attention_dim = cross_attention_dim - self.activation_fn = activation_fn - self.attention_bias = attention_bias - self.double_self_attention = double_self_attention - self.norm_elementwise_affine = norm_elementwise_affine - self.positional_embeddings = positional_embeddings - self.num_positional_embeddings = num_positional_embeddings - self.only_cross_attention = only_cross_attention - - self.set_free_noise_properties(context_length, context_stride, weighting_scheme) - - # We keep these boolean flags for backward-compatibility. 
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - self.use_ada_layer_norm_single = norm_type == "ada_norm_single" - self.use_layer_norm = norm_type == "layer_norm" - self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - self.norm_type = norm_type - self.num_embeds_ada_norm = num_embeds_ada_norm - - if positional_embeddings and (num_positional_embeddings is None): - raise ValueError( - "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." - ) - - if positional_embeddings == "sinusoidal": - self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) - else: - self.pos_embed = None - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, + def __new__(cls, *args, **kwargs): + logger.warning( + "Importing `FreeNoiseTransformerBlock` from `diffusers.models.attention` is deprecated and will be removed in a future version. " + "Please use `from diffusers.models.transformers.modeling_common import FreeNoiseTransformerBlock` instead." ) + from .transformers.modeling_common import FreeNoiseTransformerBlock - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, - ) # is self-attn if encoder_hidden_states is none - - # 3. 
Feed-forward - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - inner_dim=ff_inner_dim, - bias=ff_bias, - ) - - self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: - frame_indices = [] - for i in range(0, num_frames - self.context_length + 1, self.context_stride): - window_start = i - window_end = min(num_frames, i + self.context_length) - frame_indices.append((window_start, window_end)) - return frame_indices - - def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: - if weighting_scheme == "flat": - weights = [1.0] * num_frames - - elif weighting_scheme == "pyramid": - if num_frames % 2 == 0: - # num_frames = 4 => [1, 2, 2, 1] - mid = num_frames // 2 - weights = list(range(1, mid + 1)) - weights = weights + weights[::-1] - else: - # num_frames = 5 => [1, 2, 3, 2, 1] - mid = (num_frames + 1) // 2 - weights = list(range(1, mid)) - weights = weights + [mid] + weights[::-1] - - elif weighting_scheme == "delayed_reverse_sawtooth": - if num_frames % 2 == 0: - # num_frames = 4 => [0.01, 2, 2, 1] - mid = num_frames // 2 - weights = [0.01] * (mid - 1) + [mid] - weights = weights + list(range(mid, 0, -1)) - else: - # num_frames = 5 => [0.01, 0.01, 3, 2, 1] - mid = (num_frames + 1) // 2 - weights = [0.01] * mid - weights = weights + list(range(mid, 0, -1)) - else: - raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}") - - return weights - - def set_free_noise_properties( - self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid" - ) -> None: - self.context_length = context_length - self.context_stride = context_stride - self.weighting_scheme = weighting_scheme + return FreeNoiseTransformerBlock(*args, **kwargs) - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None: - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - *args, - **kwargs, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") - - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - - # hidden_states: [B x H x W, F, C] - device = hidden_states.device - dtype = hidden_states.dtype - - num_frames = hidden_states.size(1) - frame_indices = self._get_frame_indices(num_frames) - frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme) - frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1) - is_last_frame_batch_complete = frame_indices[-1][1] == num_frames - - # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length - # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges: - # [(0, 16), (4, 20), (8, 24), (10, 26)] - if not is_last_frame_batch_complete: - if num_frames < self.context_length: - raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}") - last_frame_batch_length = num_frames - frame_indices[-1][1] - frame_indices.append((num_frames - self.context_length, num_frames)) - - num_times_accumulated = torch.zeros((1, num_frames, 1), device=device) - accumulated_values = torch.zeros_like(hidden_states) - - for i, (frame_start, frame_end) in enumerate(frame_indices): - # The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle - # cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or - # essentially a non-multiple of `context_length`. - weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end]) - weights *= frame_weights - - hidden_states_chunk = hidden_states[:, frame_start:frame_end] - - # Notice that normalization is always applied before the real computation in the following blocks. - # 1. Self-Attention - norm_hidden_states = self.norm1(hidden_states_chunk) - - if self.pos_embed is not None: - norm_hidden_states = self.pos_embed(norm_hidden_states) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - hidden_states_chunk = attn_output + hidden_states_chunk - if hidden_states_chunk.ndim == 4: - hidden_states_chunk = hidden_states_chunk.squeeze(1) - - # 2. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = self.norm2(hidden_states_chunk) - - if self.pos_embed is not None and self.norm_type != "ada_norm_single": - norm_hidden_states = self.pos_embed(norm_hidden_states) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states_chunk = attn_output + hidden_states_chunk - - if i == len(frame_indices) - 1 and not is_last_frame_batch_complete: - accumulated_values[:, -last_frame_batch_length:] += ( - hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:] - ) - num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length] - else: - accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights - num_times_accumulated[:, frame_start:frame_end] += weights - - # TODO(aryan): Maybe this could be done in a better way. 
- # - # Previously, this was: - # hidden_states = torch.where( - # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values - # ) - # - # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory - # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes - # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly - # looked into this deeply because other memory optimizations led to more pronounced reductions. - hidden_states = torch.cat( - [ - torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split) - for accumulated_split, num_times_split in zip( - accumulated_values.split(self.context_length, dim=1), - num_times_accumulated.split(self.context_length, dim=1), - ) - ], - dim=1, - ).to(dtype) - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self._chunk_size is not None: - ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) - else: - ff_output = self.ff(norm_hidden_states) - - hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - - return hidden_states - - -class FeedForward(nn.Module): +class FeedForward: r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + Backward compatibility stub. Use transformers.modeling_common.FeedForward instead. """ - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - inner_dim=None, - bias: bool = True, - ): - super().__init__() - if inner_dim is None: - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim, bias=bias) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim, bias=bias) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim, bias=bias) - elif activation_fn == "swiglu": - act_fn = SwiGLU(dim, inner_dim, bias=bias) - elif activation_fn == "linear-silu": - act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) - # FF as used in Vision Transformer, MLP-Mixer, etc. 
have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." - deprecate("scale", "1.0.0", deprecation_message) - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states + def __new__(cls, *args, **kwargs): + logger.warning( + "Importing `FeedForward` from `diffusers.models.attention` is deprecated and will be removed in a future version. " + "Please use `from diffusers.models.transformers.modeling_common import FeedForward` instead." + ) + from .transformers.modeling_common import FeedForward + + return FeedForward(*args, **kwargs) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index b60f0636e6dc..96a29a4dc514 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -10,6 +10,17 @@ from .hunyuan_transformer_2d import HunyuanDiT2DModel from .latte_transformer_3d import LatteTransformer3DModel from .lumina_nextdit2d import LuminaNextDiT2DModel + from .modeling_common import ( + BasicTransformerBlock, + FeedForward, + FreeNoiseTransformerBlock, + GatedSelfAttentionDense, + JointTransformerBlock, + LuminaFeedForward, + SkipFFTransformerBlock, + TemporalBasicTransformerBlock, + _chunked_feed_forward, + ) from .pixart_transformer_2d import PixArtTransformer2DModel from .prior_transformer import PriorTransformer from .sana_transformer import SanaTransformer2DModel diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index a8c98bccb86c..354619c00d24 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -22,13 +22,18 @@ from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import Attention, FeedForward -from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 +from ..attention_processor import ( + Attention, + AttentionProcessor, + CogVideoXAttnProcessor2_0, + FusedCogVideoXAttnProcessor2_0, +) from ..cache_utils import CacheMixin from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero +from .modeling_common import FeedForward logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py index 41632dbd4751..fb5254c452a3 100644 --- a/src/diffusers/models/transformers/consisid_transformer_3d.py +++ b/src/diffusers/models/transformers/consisid_transformer_3d.py @@ -22,12 +22,12 @@ from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import 
maybe_allow_in_graph -from ..attention import Attention, FeedForward -from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0 +from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero +from .modeling_common import FeedForward logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index 68f6f769436e..989bf4787ac8 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -19,10 +19,10 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging -from ..attention import BasicTransformerBlock from ..embeddings import PatchEmbed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin +from .modeling_common import BasicTransformerBlock logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index f63471878857..722a98cb9b31 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -19,7 +19,6 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0 from ..embeddings import ( HunyuanCombinedTimestepTextSizeStyleEmbedding, @@ -29,6 +28,7 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, FP32LayerNorm +from .modeling_common import FeedForward logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 990c90512e39..b623b1f3955e 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -18,12 +18,12 @@ from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ..attention import BasicTransformerBlock from ..cache_utils import CacheMixin from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle +from .modeling_common import BasicTransformerBlock class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py index 84b1175386b0..b5763b33aa56 100644 --- a/src/diffusers/models/transformers/lumina_nextdit2d.py +++ b/src/diffusers/models/transformers/lumina_nextdit2d.py @@ -19,7 +19,6 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging -from 
..attention import LuminaFeedForward from ..attention_processor import Attention, LuminaAttnProcessor2_0 from ..embeddings import ( LuminaCombinedTimestepCaptionEmbedding, @@ -28,6 +27,7 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm +from .modeling_common import LuminaFeedForward logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/modeling_common.py b/src/diffusers/models/transformers/modeling_common.py new file mode 100644 index 000000000000..f4baff19b1b1 --- /dev/null +++ b/src/diffusers/models/transformers/modeling_common.py @@ -0,0 +1,1258 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...utils import deprecate, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU +from ..attention_processor import Attention, JointAttnProcessor2_0 +from ..embeddings import SinusoidalPositionalEmbedding +from ..normalization import ( + AdaLayerNorm, + AdaLayerNormContinuous, + AdaLayerNormZero, + RMSNorm, + SD35AdaLayerNormZeroX, +) + + +logger = logging.get_logger(__name__) + + +def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + return ff_output + + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. 
+ """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x + + +@maybe_allow_in_graph +class JointTransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://huggingface.co/papers/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + context_pre_only: bool = False, + qk_norm: Optional[str] = None, + use_dual_attention: bool = False, + ): + super().__init__() + + self.use_dual_attention = use_dual_attention + self.context_pre_only = context_pre_only + context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" + + if use_dual_attention: + self.norm1 = SD35AdaLayerNormZeroX(dim) + else: + self.norm1 = AdaLayerNormZero(dim) + + if context_norm_type == "ada_norm_continous": + self.norm1_context = AdaLayerNormContinuous( + dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" + ) + elif context_norm_type == "ada_norm_zero": + self.norm1_context = AdaLayerNormZero(dim) + else: + raise ValueError( + f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" + ) + + if hasattr(F, "scaled_dot_product_attention"): + processor = JointAttnProcessor2_0() + else: + raise ValueError( + "The current PyTorch version does not support the `scaled_dot_product_attention` function." 
+            )
+
+        self.attn = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            added_kv_proj_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            context_pre_only=context_pre_only,
+            bias=True,
+            processor=processor,
+            qk_norm=qk_norm,
+            eps=1e-6,
+        )
+
+        if use_dual_attention:
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=None,
+                dim_head=attention_head_dim,
+                heads=num_attention_heads,
+                out_dim=dim,
+                bias=True,
+                processor=processor,
+                qk_norm=qk_norm,
+                eps=1e-6,
+            )
+        else:
+            self.attn2 = None
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+        if not context_pre_only:
+            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+        else:
+            self.norm2_context = None
+            self.ff_context = None
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    # Copied from diffusers.models.transformers.modeling_common.BasicTransformerBlock.set_chunk_feed_forward
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
+        if self.use_dual_attention:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+                hidden_states, emb=temb
+            )
+        else:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+        if self.context_pre_only:
+            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+        else:
+            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+                encoder_hidden_states, emb=temb
+            )
+
+        # Attention.
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        if self.use_dual_attention:
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
+            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+            hidden_states = hidden_states + attn_output2
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        # Process attention outputs for the `encoder_hidden_states`.
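+        # When `context_pre_only` is True, the text stream stops here and only the image
+        # stream is returned; otherwise the context tokens receive the same gated
+        # attention-residual and feed-forward treatment as `hidden_states` above.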
+        if self.context_pre_only:
+            encoder_hidden_states = None
+        else:
+            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+            encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+            if self._chunk_size is not None:
+                # "feed_forward_chunk_size" can be used to save memory
+                context_ff_output = _chunked_feed_forward(
+                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+                )
+            else:
+                context_ff_output = self.ff_context(norm_encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return encoder_hidden_states, hidden_states
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (`int`, *optional*):
+            The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Configure if the attentions should contain a bias parameter.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, *optional*):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        upcast_attention (`bool`, *optional*):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+        final_dropout (`bool` *optional*, defaults to False):
+            Whether to apply a final dropout after the last feed-forward layer.
+        attention_type (`str`, *optional*, defaults to `"default"`):
+            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+        positional_embeddings (`str`, *optional*, defaults to `None`):
+            The type of positional embeddings to apply.
+        num_positional_embeddings (`int`, *optional*, defaults to `None`):
+            The maximum number of positional embeddings to apply.
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. 
+        if norm_type == "ada_norm":
+            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+        elif norm_type == "ada_norm_zero":
+            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+        elif norm_type == "ada_norm_continuous":
+            self.norm1 = AdaLayerNormContinuous(
+                dim,
+                ada_norm_continous_conditioning_embedding_dim,
+                norm_elementwise_affine,
+                norm_eps,
+                ada_norm_bias,
+                "rms_norm",
+            )
+        else:
+            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
+            out_bias=attention_out_bias,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None or double_self_attention:
+            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+            # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned
+            # during the second cross attention block.
+            if norm_type == "ada_norm":
+                self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+            elif norm_type == "ada_norm_continuous":
+                self.norm2 = AdaLayerNormContinuous(
+                    dim,
+                    ada_norm_continous_conditioning_embedding_dim,
+                    norm_elementwise_affine,
+                    norm_eps,
+                    ada_norm_bias,
+                    "rms_norm",
+                )
+            else:
+                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+                out_bias=attention_out_bias,
+            )  # is self-attn if encoder_hidden_states is none
+        else:
+            if norm_type == "ada_norm_single":  # For Latte
+                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+            else:
+                self.norm2 = None
+            self.attn2 = None
+
+        # 3. Feed-forward
+        if norm_type == "ada_norm_continuous":
+            self.norm3 = AdaLayerNormContinuous(
+                dim,
+                ada_norm_continous_conditioning_embedding_dim,
+                norm_elementwise_affine,
+                norm_eps,
+                ada_norm_bias,
+                "layer_norm",
+            )
+
+        elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
+            self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+        elif norm_type == "layer_norm_i2vgen":
+            self.norm3 = None
+
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+            inner_dim=ff_inner_dim,
+            bias=ff_bias,
+        )
+
+        # 4. Fuser
+        if attention_type == "gated" or attention_type == "gated-text-image":
+            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+        # 5. Scale-shift for PixArt-Alpha.
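+        # `ada_norm_single` is the PixArt-Alpha variant: instead of a per-block adaptive
+        # norm predicting the six modulation parameters, a single learned
+        # `scale_shift_table` is combined with the projected timestep embedding in `forward`.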
+ if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. 
Cross-Attention
+        if self.attn2 is not None:
+            if self.norm_type == "ada_norm":
+                norm_hidden_states = self.norm2(hidden_states, timestep)
+            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
+                norm_hidden_states = self.norm2(hidden_states)
+            elif self.norm_type == "ada_norm_single":
+                # For PixArt norm2 isn't applied here:
+                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+                norm_hidden_states = hidden_states
+            elif self.norm_type == "ada_norm_continuous":
+                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+            else:
+                raise ValueError("Incorrect norm")
+
+            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+                norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+            attn_output = self.attn2(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                **cross_attention_kwargs,
+            )
+            hidden_states = attn_output + hidden_states
+
+        # 4. Feed-forward
+        # i2vgen doesn't have this norm 🤷‍♂️
+        if self.norm_type == "ada_norm_continuous":
+            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
+        elif not self.norm_type == "ada_norm_single":
+            norm_hidden_states = self.norm3(hidden_states)
+
+        if self.norm_type == "ada_norm_zero":
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        if self.norm_type == "ada_norm_single":
+            norm_hidden_states = self.norm2(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        if self.norm_type == "ada_norm_zero":
+            ff_output = gate_mlp.unsqueeze(1) * ff_output
+        elif self.norm_type == "ada_norm_single":
+            ff_output = gate_mlp * ff_output
+
+        hidden_states = ff_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        return hidden_states
+
+
+class LuminaFeedForward(nn.Module):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (`int`):
+            The dimensionality of the input. This parameter determines the width of the model's hidden
+            representations.
+        inner_dim (`int`): The intermediate dimension of the feed-forward layer.
+        multiple_of (`int`, *optional*): Value to ensure the hidden dimension is a multiple
+            of this value.
+        ffn_dim_multiplier (float, *optional*): Custom multiplier for the hidden
+            dimension. Defaults to None.
+ """ + + def __init__( + self, + dim: int, + inner_dim: int, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + # custom hidden_size factor multiplier + if ffn_dim_multiplier is not None: + inner_dim = int(ffn_dim_multiplier * inner_dim) + inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) + + self.linear_1 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.linear_2 = nn.Linear( + inner_dim, + dim, + bias=False, + ) + self.linear_3 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.silu = FP32SiLU() + + def forward(self, x): + return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x)) + + +@maybe_allow_in_graph +class TemporalBasicTransformerBlock(nn.Module): + r""" + A basic Transformer block for video like data. + + Parameters: + dim (`int`): The number of channels in the input and output. + time_mix_inner_dim (`int`): The number of channels for temporal attention. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + """ + + def __init__( + self, + dim: int, + time_mix_inner_dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.is_res = dim == time_mix_inner_dim + + self.norm_in = nn.LayerNorm(dim) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.ff_in = FeedForward( + dim, + dim_out=time_mix_inner_dim, + activation_fn="geglu", + ) + + self.norm1 = nn.LayerNorm(time_mix_inner_dim) + self.attn1 = Attention( + query_dim=time_mix_inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + cross_attention_dim=None, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = nn.LayerNorm(time_mix_inner_dim) + self.attn2 = Attention( + query_dim=time_mix_inner_dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(time_mix_inner_dim) + self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = None + + def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): + # Sets chunk feed-forward + self._chunk_size = chunk_size + # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off + self._chunk_dim = 1 + + def forward( + self, + hidden_states: torch.Tensor, + num_frames: int, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. 
+        batch_frames, seq_length, channels = hidden_states.shape
+        batch_size = batch_frames // num_frames
+
+        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
+        hidden_states = hidden_states.permute(0, 2, 1, 3)
+        hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
+
+        residual = hidden_states
+        hidden_states = self.norm_in(hidden_states)
+
+        if self._chunk_size is not None:
+            hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            hidden_states = self.ff_in(hidden_states)
+
+        if self.is_res:
+            hidden_states = hidden_states + residual
+
+        norm_hidden_states = self.norm1(hidden_states)
+        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
+        hidden_states = attn_output + hidden_states
+
+        # 3. Cross-Attention
+        if self.attn2 is not None:
+            norm_hidden_states = self.norm2(hidden_states)
+            attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+            hidden_states = attn_output + hidden_states
+
+        # 4. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self._chunk_size is not None:
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        if self.is_res:
+            hidden_states = ff_output + hidden_states
+        else:
+            hidden_states = ff_output
+
+        hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
+        hidden_states = hidden_states.permute(0, 2, 1, 3)
+        hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
+
+        return hidden_states
+
+
+class SkipFFTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        kv_input_dim: int,
+        kv_input_dim_proj_use_bias: bool,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        attention_bias: bool = False,
+        attention_out_bias: bool = True,
+    ):
+        super().__init__()
+        if kv_input_dim != dim:
+            self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
+        else:
+            self.kv_mapper = None
+
+        self.norm1 = RMSNorm(dim, 1e-06)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim,
+            out_bias=attention_out_bias,
+        )
+
+        self.norm2 = RMSNorm(dim, 1e-06)
+
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            out_bias=attention_out_bias,
+        )
+
+    def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+        if self.kv_mapper is not None:
+            encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
+
+        norm_hidden_states = self.norm1(hidden_states)
+
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            **cross_attention_kwargs,
+        )
+
+        hidden_states = attn_output + hidden_states
+
+        norm_hidden_states = self.norm2(hidden_states)
+
+        attn_output = self.attn2(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            **cross_attention_kwargs,
+        )
+
+        hidden_states = attn_output + hidden_states
+
+        return hidden_states
+
+
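+# Illustrative usage sketch (not part of the library): every block in this module
+# operates on `(batch, sequence, channels)` tensors. Assuming the package layout added
+# by this file, a plain `BasicTransformerBlock` can be exercised like so:
+#
+#     import torch
+#     from diffusers.models.transformers.modeling_common import BasicTransformerBlock
+#
+#     block = BasicTransformerBlock(dim=320, num_attention_heads=8, attention_head_dim=40)
+#     sample = torch.randn(2, 64, 320)  # (batch, tokens, channels)
+#     out = block(sample)  # self-attention only; output keeps shape (2, 64, 320)
+#
+#     # `set_chunk_feed_forward` trades speed for memory by running the feed-forward
+#     # over slices of the token dimension; 64 must be divisible by `chunk_size`.
+#     block.set_chunk_feed_forward(chunk_size=32, dim=1)
+#     out = block(sample)
+
+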
+@maybe_allow_in_graph
+class FreeNoiseTransformerBlock(nn.Module):
+    r"""
+    A FreeNoise Transformer block.
+
+    Parameters:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        cross_attention_dim (`int`, *optional*):
+            The size of the encoder_hidden_states vector for cross attention.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`):
+            Activation function to be used in feed-forward.
+        num_embeds_ada_norm (`int`, *optional*):
+            The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (`bool`, defaults to `False`):
+            Configure if the attentions should contain a bias parameter.
+        only_cross_attention (`bool`, defaults to `False`):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, defaults to `False`):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        upcast_attention (`bool`, defaults to `False`):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+        final_dropout (`bool`, defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
+        attention_type (`str`, defaults to `"default"`):
+            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+        positional_embeddings (`str`, *optional*):
+            The type of positional embeddings to apply.
+        num_positional_embeddings (`int`, *optional*, defaults to `None`):
+            The maximum number of positional embeddings to apply.
+        ff_inner_dim (`int`, *optional*):
+            Hidden dimension of feed-forward MLP.
+        ff_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in feed-forward MLP.
+        attention_out_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in the attention output projection layer.
+        context_length (`int`, defaults to `16`):
+            The maximum number of frames that the FreeNoise block processes at once.
+        context_stride (`int`, defaults to `4`):
+            The number of frames to be skipped before starting to process a new batch of `context_length` frames.
+        weighting_scheme (`str`, defaults to `"pyramid"`):
+            The weighting scheme to use for weighted averaging of processed latent frames. As described in
+            Equation 9 of the [FreeNoise](https://huggingface.co/papers/2310.15169) paper, "pyramid" is the
+            default setting used.
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + final_dropout: bool = False, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + context_length: int = 16, + context_stride: int = 4, + weighting_scheme: str = "pyramid", + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + self.set_free_noise_properties(context_length, context_stride, weighting_scheme) + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. 
Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: + frame_indices = [] + for i in range(0, num_frames - self.context_length + 1, self.context_stride): + window_start = i + window_end = min(num_frames, i + self.context_length) + frame_indices.append((window_start, window_end)) + return frame_indices + + def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: + if weighting_scheme == "flat": + weights = [1.0] * num_frames + + elif weighting_scheme == "pyramid": + if num_frames % 2 == 0: + # num_frames = 4 => [1, 2, 2, 1] + mid = num_frames // 2 + weights = list(range(1, mid + 1)) + weights = weights + weights[::-1] + else: + # num_frames = 5 => [1, 2, 3, 2, 1] + mid = (num_frames + 1) // 2 + weights = list(range(1, mid)) + weights = weights + [mid] + weights[::-1] + + elif weighting_scheme == "delayed_reverse_sawtooth": + if num_frames % 2 == 0: + # num_frames = 4 => [0.01, 2, 2, 1] + mid = num_frames // 2 + weights = [0.01] * (mid - 1) + [mid] + weights = weights + list(range(mid, 0, -1)) + else: + # num_frames = 5 => [0.01, 0.01, 3, 2, 1] + mid = (num_frames + 1) // 2 + weights = [0.01] * mid + weights = weights + list(range(mid, 0, -1)) + else: + raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}") + + return weights + + def set_free_noise_properties( + self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid" + ) -> None: + self.context_length = context_length + self.context_stride = context_stride + self.weighting_scheme = weighting_scheme + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None: + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.")
+
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+        # hidden_states: [B x H x W, F, C]
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        num_frames = hidden_states.size(1)
+        frame_indices = self._get_frame_indices(num_frames)
+        frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
+        frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
+        is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
+
+        # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
+        # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
+        # [(0, 16), (4, 20), (8, 24), (9, 25)]
+        if not is_last_frame_batch_complete:
+            if num_frames < self.context_length:
+                raise ValueError(f"Expected {num_frames=} to be greater than or equal to {self.context_length=}")
+            last_frame_batch_length = num_frames - frame_indices[-1][1]
+            frame_indices.append((num_frames - self.context_length, num_frames))
+
+        num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
+        accumulated_values = torch.zeros_like(hidden_states)
+
+        for i, (frame_start, frame_end) in enumerate(frame_indices):
+            # The reason for slicing here is to handle cases like frame_indices=[(0, 16), (16, 20)],
+            # i.e. when the user provides a video whose frame count is not a multiple of `context_length`
+            # (say, 19 frames) and the last window is shorter than the others.
+            weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
+            weights *= frame_weights
+
+            hidden_states_chunk = hidden_states[:, frame_start:frame_end]
+
+            # Notice that normalization is always applied before the real computation in the following blocks.
+            # 1. Self-Attention
+            norm_hidden_states = self.norm1(hidden_states_chunk)
+
+            if self.pos_embed is not None:
+                norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+            attn_output = self.attn1(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+                attention_mask=attention_mask,
+                **cross_attention_kwargs,
+            )
+
+            hidden_states_chunk = attn_output + hidden_states_chunk
+            if hidden_states_chunk.ndim == 4:
+                hidden_states_chunk = hidden_states_chunk.squeeze(1)
+
+            # 2. Cross-Attention
+            if self.attn2 is not None:
+                norm_hidden_states = self.norm2(hidden_states_chunk)
+
+                if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+                    norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+                attn_output = self.attn2(
+                    norm_hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=encoder_attention_mask,
+                    **cross_attention_kwargs,
+                )
+                hidden_states_chunk = attn_output + hidden_states_chunk
+
+            if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
+                accumulated_values[:, -last_frame_batch_length:] += (
+                    hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
+                )
+                num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
+            else:
+                accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
+                num_times_accumulated[:, frame_start:frame_end] += weights
+
+        # TODO(aryan): Maybe this could be done in a better way.
+        #
+        # Previously, this was:
+        #     hidden_states = torch.where(
+        #         num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        #     )
+        #
+        # The reasoning for the change here is that `torch.where` became a bottleneck at some point when chasing
+        # down memory spikes. It is particularly noticeable when the number of frames is high. My understanding is
+        # that this comes from tensors being copied - which is why we resort to splitting and concatenating here.
+        # I've not particularly looked into this deeply because other memory optimizations led to more pronounced
+        # reductions.
+        hidden_states = torch.cat(
+            [
+                torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+                for accumulated_split, num_times_split in zip(
+                    accumulated_values.split(self.context_length, dim=1),
+                    num_times_accumulated.split(self.context_length, dim=1),
+                )
+            ],
+            dim=1,
+        ).to(dtype)
+
+        # 3. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self._chunk_size is not None:
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        hidden_states = ff_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        return hidden_states
+
+
+class FeedForward(nn.Module):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (`int`): The number of channels in the input.
+        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        mult: int = 4,
+        dropout: float = 0.0,
+        activation_fn: str = "geglu",
+        final_dropout: bool = False,
+        inner_dim=None,
+        bias: bool = True,
+    ):
+        super().__init__()
+        if inner_dim is None:
+            inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim, bias=bias)
+        elif activation_fn == "gelu-approximate":
+            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim, bias=bias)
+        elif activation_fn == "geglu-approximate":
+            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+        elif activation_fn == "swiglu":
+            act_fn = SwiGLU(dim, inner_dim, bias=bias)
+        elif activation_fn == "linear-silu":
+            act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
+        else:
+            raise ValueError(f"Unsupported activation function: {activation_fn}")
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(act_fn)
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+        if final_dropout:
+            self.net.append(nn.Dropout(dropout))
+
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. 
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 40a14bfd9b27..5b13c24c69ca 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -18,12 +18,12 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging -from ..attention import BasicTransformerBlock from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle +from .modeling_common import BasicTransformerBlock logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py index 565da0da8b6e..3c2a5afd6b71 100644 --- a/src/diffusers/models/transformers/prior_transformer.py +++ b/src/diffusers/models/transformers/prior_transformer.py @@ -8,7 +8,6 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin from ...utils import BaseOutput -from ..attention import BasicTransformerBlock from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -18,6 +17,7 @@ ) from ..embeddings import TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin +from .modeling_common import BasicTransformerBlock @dataclass diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py index 969e6db122d9..68caff835ce4 100644 --- a/src/diffusers/models/transformers/stable_audio_transformer.py +++ b/src/diffusers/models/transformers/stable_audio_transformer.py @@ -23,10 +23,10 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward from ..attention_processor import Attention, AttentionProcessor, StableAudioAttnProcessor2_0 from ..modeling_utils import ModelMixin from ..transformers.transformer_2d import Transformer2DModelOutput +from .modeling_common import FeedForward logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 67fe9a33109b..97376e9415b2 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -19,11 +19,11 @@ from ...configuration_utils import LegacyConfigMixin, register_to_config from ...utils import deprecate, logging -from ..attention import BasicTransformerBlock from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import LegacyModelMixin from ..normalization import AdaLayerNormSingle +from .modeling_common import BasicTransformerBlock logger = 
logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 5fa59a71d977..567d2aeebafa 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -22,13 +22,13 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import FeedForward from ..attention_processor import AllegroAttnProcessor2_0, Attention from ..cache_utils import CacheMixin from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle +from .modeling_common import FeedForward logger = logging.get_logger(__name__) diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py index 27a9941501a1..cd2c34cc6fba 100644 --- a/src/diffusers/models/transformers/transformer_bria.py +++ b/src/diffusers/models/transformers/transformer_bria.py @@ -10,13 +10,14 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionModuleMixin, FeedForward +from ..attention import AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import TimestepEmbedding, apply_rotary_emb, get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +from .modeling_common import FeedForward logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 5823ae9d3da6..8067d60f4f7b 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -24,12 +24,13 @@ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionMixin, FeedForward +from ..attention import AttentionMixin from ..cache_utils import CacheMixin from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm +from .modeling_common import FeedForward from .transformer_flux import FluxAttention, FluxAttnProcessor diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 77f15f6ca6f1..6a8583b6da4e 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -20,12 +20,12 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging -from ..attention import FeedForward from 
..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 25dcfa14cc0b..0399b15c015a 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -22,13 +22,13 @@
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import LayerNorm, RMSNorm
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py
index 373b470ae37b..08106852a7d4 100644
--- a/src/diffusers/models/transformers/transformer_cosmos.py
+++ b/src/diffusers/models/transformers/transformer_cosmos.py
@@ -22,12 +22,12 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin
 from ...utils import is_torchvision_available
-from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..embeddings import Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import RMSNorm
+from .modeling_common import FeedForward


 if is_torchvision_available():
diff --git a/src/diffusers/models/transformers/transformer_easyanimate.py b/src/diffusers/models/transformers/transformer_easyanimate.py
index 545fa29730db..6b091428ec88 100755
--- a/src/diffusers/models/transformers/transformer_easyanimate.py
+++ b/src/diffusers/models/transformers/transformer_easyanimate.py
@@ -22,11 +22,12 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, FeedForward
+from ..attention_processor import Attention
 from ..embeddings import TimestepEmbedding, Timesteps, get_3d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, FP32LayerNorm, RMSNorm
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 60c7eb1dbabe..c5c59dacd0d3 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -25,7 +25,7 @@
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention import AttentionMixin, AttentionModuleMixin
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import (
@@ -37,6 +37,7 @@
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py
index 77902dcf5852..8c0c30efaeab 100644
--- a/src/diffusers/models/transformers/transformer_hidream_image.py
+++ b/src/diffusers/models/transformers/transformer_hidream_image.py
@@ -10,7 +10,7 @@
 from ...models.modeling_utils import ModelMixin
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention
+from ..attention_processor import Attention
 from ..embeddings import TimestepEmbedding, Timesteps


diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index 6944a6c536b5..d7d7b8d73452 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -23,7 +23,6 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
 from ..attention_processor import Attention, AttentionProcessor
 from ..cache_utils import CacheMixin
 from ..embeddings import (
@@ -36,6 +35,7 @@
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index 79149fb76067..2d6637a9fb03 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -24,13 +24,14 @@
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention import AttentionMixin, AttentionModuleMixin
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py
index 77121edb9fc9..1519ac3bb699 100644
--- a/src/diffusers/models/transformers/transformer_lumina2.py
+++ b/src/diffusers/models/transformers/transformer_lumina2.py
@@ -23,12 +23,12 @@
 from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import LuminaFeedForward
 from ..attention_processor import Attention
 from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
+from .modeling_common import LuminaFeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py
index 63911fe7c10d..8cb48d8265c1 100644
--- a/src/diffusers/models/transformers/transformer_mochi.py
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -23,13 +23,13 @@
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
 from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
 from ..cache_utils import CacheMixin
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, RMSNorm
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index 846add8906ac..6ac4c29a1eb4 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -25,7 +25,7 @@
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import AttentionMixin, FeedForward
+from ..attention import AttentionMixin
 from ..attention_dispatch import dispatch_attention_fn
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
@@ -33,6 +33,7 @@
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, RMSNorm
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index edf77a7df793..26ab1faf6441 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -20,7 +20,6 @@
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward, JointTransformerBlock
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
@@ -31,6 +30,7 @@
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
+from .modeling_common import FeedForward, JointTransformerBlock


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py
index 236fca690a90..15c7131700b2 100644
--- a/src/diffusers/models/transformers/transformer_skyreels_v2.py
+++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py
@@ -22,7 +22,6 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import (
@@ -34,6 +33,7 @@
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin, get_parameter_dtype
 from ..normalization import FP32LayerNorm
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py
index ffaf31d04570..011674429159 100644
--- a/src/diffusers/models/transformers/transformer_temporal.py
+++ b/src/diffusers/models/transformers/transformer_temporal.py
@@ -19,10 +19,10 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import BaseOutput
-from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
 from ..resnet import AlphaBlender
+from .modeling_common import BasicTransformerBlock, TemporalBasicTransformerBlock


 @dataclass
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 968a0369c243..f899748103f9 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -23,13 +23,14 @@
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention import AttentionMixin, AttentionModuleMixin
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py
index e039d362193d..f52aa70cd6e1 100644
--- a/src/diffusers/models/transformers/transformer_wan_vace.py
+++ b/src/diffusers/models/transformers/transformer_wan_vace.py
@@ -21,11 +21,11 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
 from ..cache_utils import CacheMixin
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
+from .modeling_common import FeedForward
 from .transformer_wan import (
     WanAttention,
     WanAttnProcessor,
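# Editor's illustration, not part of the patch: a minimal sketch of the import
# migration this diff performs. The shared blocks are now defined in
# `diffusers.models.transformers.modeling_common`; the old
# `diffusers.models.attention` path is assumed to keep resolving through the
# deprecation stubs, with a warning. Names like `dim` below are hypothetical.
import torch

from diffusers.models.transformers.modeling_common import FeedForward  # new canonical path

dim = 64  # hypothetical channel width
ff = FeedForward(dim, activation_fn="geglu")
hidden_states = torch.randn(1, 16, dim)
out = ff(hidden_states)  # feed-forward preserves the shape: (1, 16, 64)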