diff --git a/autoencoder.py b/autoencoder.py
index 7debb21..f221a0e 100644
--- a/autoencoder.py
+++ b/autoencoder.py
@@ -468,6 +468,8 @@ def __init__(
         dim_head=64,
         dropout=0.0,
         backend=None,
+        layer=0,
+        **kwargs
     ):
         super().__init__()
         inner_dim = dim_head * heads
@@ -1097,7 +1099,8 @@ def __init__(
     ):
         super().__init__()
         assert attn_mode in self.ATTENTION_MODES
-        if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
+        # linear attention doesn't depend on xformers
+        if attn_mode == "softmax-xformers" and not XFORMERS_IS_AVAILABLE:
             print(
                 f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                 f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
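
The second hunk narrows the fallback check so that only the xformers-backed mode triggers the warning; as the new comment notes, linear attention does not depend on xformers. A minimal standalone sketch of that selection rule is below; the `resolve_attn_mode` helper and the exact mode names in the tuple are assumptions for illustration, not part of the patched file:

```python
# Sketch of the fallback rule the patch tightens. The mode names are inferred
# from the diff ("softmax-xformers") and its comment about linear attention;
# the helper name is hypothetical.
ATTENTION_MODES = ("softmax", "softmax-xformers", "linear")

def resolve_attn_mode(attn_mode: str, xformers_available: bool) -> str:
    assert attn_mode in ATTENTION_MODES
    # Only the xformers-backed mode actually needs xformers; plain softmax and
    # linear attention run natively, so they should not trigger the fallback.
    if attn_mode == "softmax-xformers" and not xformers_available:
        print(f"Attention mode '{attn_mode}' is not available. Falling back to native attention.")
        return "softmax"
    return attn_mode

# Before the patch, the condition `attn_mode != "softmax"` would also warn and
# fall back for "linear", even though it never required xformers.
print(resolve_attn_mode("linear", xformers_available=False))  # -> "linear"
```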