Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,256 changes: 80 additions & 1,176 deletions src/diffusers/models/attention.py

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions src/diffusers/models/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@
from .hunyuan_transformer_2d import HunyuanDiT2DModel
from .latte_transformer_3d import LatteTransformer3DModel
from .lumina_nextdit2d import LuminaNextDiT2DModel
from .modeling_common import (
BasicTransformerBlock,
FeedForward,
FreeNoiseTransformerBlock,
GatedSelfAttentionDense,
JointTransformerBlock,
LuminaFeedForward,
SkipFFTransformerBlock,
TemporalBasicTransformerBlock,
_chunked_feed_forward,
)
from .pixart_transformer_2d import PixArtTransformer2DModel
from .prior_transformer import PriorTransformer
from .sana_transformer import SanaTransformer2DModel
Expand Down
9 changes: 7 additions & 2 deletions src/diffusers/models/transformers/cogvideox_transformer_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..attention_processor import (
Attention,
AttentionProcessor,
CogVideoXAttnProcessor2_0,
FusedCogVideoXAttnProcessor2_0,
)
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
from .modeling_common import FeedForward


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/models/transformers/consisid_transformer_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
from .modeling_common import FeedForward


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/dit_transformer_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import BasicTransformerBlock
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from .modeling_common import BasicTransformerBlock


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
from ..embeddings import (
HunyuanCombinedTimestepTextSizeStyleEmbedding,
Expand All @@ -29,6 +28,7 @@
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, FP32LayerNorm
from .modeling_common import FeedForward


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/latte_transformer_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ..attention import BasicTransformerBlock
from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
from .modeling_common import BasicTransformerBlock


class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/lumina_nextdit2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import LuminaFeedForward
from ..attention_processor import Attention, LuminaAttnProcessor2_0
from ..embeddings import (
LuminaCombinedTimestepCaptionEmbedding,
Expand All @@ -28,6 +27,7 @@
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
from .modeling_common import LuminaFeedForward


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down
Loading
Loading