diff --git a/examples/modular-transformers/modeling_dummy_bert.py b/examples/modular-transformers/modeling_dummy_bert.py index 07beed462032..fe91b2f8182d 100644 --- a/examples/modular-transformers/modeling_dummy_bert.py +++ b/examples/modular-transformers/modeling_dummy_bert.py @@ -42,7 +42,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -83,11 +82,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -101,38 +100,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. - attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -148,7 +122,7 @@ def eager_attention_forward( class DummyBertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, 
config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -167,12 +141,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -210,11 +178,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' - ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -225,8 +188,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -234,7 +195,7 @@ def forward( class DummyBertCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -253,12 +214,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -300,11 +255,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -315,8 +265,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -338,15 +286,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class DummyBertAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = DummyBertCrossAttention if is_cross_attention else DummyBertSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = DummyBertSelfOutput(config) self.pruned_heads = set() @@ -433,7 +377,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = DummyBertAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, @@ -638,8 +581,6 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = DummyBertPooler(config) if add_pooling_layer else None - self.position_embedding_type = config.position_embedding_type - # Initialize weights and apply final processing self.post_init() diff --git a/examples/modular-transformers/modeling_roberta.py b/examples/modular-transformers/modeling_roberta.py index 427e8f8d1572..351272a418e1 100644 --- a/examples/modular-transformers/modeling_roberta.py +++ b/examples/modular-transformers/modeling_roberta.py @@ -44,7 +44,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -86,11 +85,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -104,38 +103,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -151,7 +125,7 @@ def eager_attention_forward( class RobertaSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -170,12 +144,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -213,11 +181,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -228,8 +191,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -237,7 +198,7 @@ def forward( class RobertaCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -256,12 +217,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -303,11 +258,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -318,8 +268,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -341,15 +289,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RobertaAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = RobertaCrossAttention if is_cross_attention else RobertaSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = RobertaSelfOutput(config) self.pruned_heads = set() @@ -436,7 +380,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = RobertaAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, @@ -641,8 +584,6 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = RobertaPooler(config) if add_pooling_layer else None - self.position_embedding_type = config.position_embedding_type - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py index b60c19d504f0..110ad8cb138d 100644 --- a/src/transformers/models/albert/configuration_albert.py +++ b/src/transformers/models/albert/configuration_albert.py @@ -68,12 +68,6 @@ class AlbertConfig(PretrainedConfig): The epsilon used by the layer normalization layers. classifier_dropout_prob (`float`, *optional*, defaults to 0.1): The dropout ratio for attached classifiers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). pad_token_id (`int`, *optional*, defaults to 0): Padding token id. 
bos_token_id (`int`, *optional*, defaults to 2): @@ -123,7 +117,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, classifier_dropout_prob=0.1, - position_embedding_type="absolute", pad_token_id=0, bos_token_id=2, eos_token_id=3, @@ -147,7 +140,6 @@ def __init__( self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.classifier_dropout_prob = classifier_dropout_prob - self.position_embedding_type = position_embedding_type # Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index a1dfa5e2fc9c..d4f07251ecb1 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -69,7 +69,6 @@ def __init__(self, config: AlbertConfig): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False ) @@ -106,11 +105,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -125,38 +124,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -198,11 +172,6 @@ def __init__(self, config: AlbertConfig): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pruned_heads = set() - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - self.is_causal = False def prune_heads(self, heads: list[int]) -> None: @@ -239,11 +208,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -254,8 +218,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=False, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -429,7 +391,6 @@ def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True): self.pooler_activation = None self.attn_implementation = config._attn_implementation - self.position_embedding_type = config.position_embedding_type # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/align/configuration_align.py b/src/transformers/models/align/configuration_align.py index b924d85a6ca6..3eaeffdc6f23 100644 --- a/src/transformers/models/align/configuration_align.py +++ b/src/transformers/models/align/configuration_align.py @@ -62,12 +62,6 @@ class AlignTextConfig(PretrainedConfig): The epsilon used by the layer normalization layers. pad_token_id (`int`, *optional*, defaults to 0): Padding token id. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. 
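For orientation, this is a standalone sketch of the simplified eager attention path that the modeling hunks above (dummy_bert, roberta, albert) converge on once the relative_key branches and the `use_cache` argument are dropped. The softmax/dropout/value tail is assumed from the surrounding, unmodified code and is not part of the hunks shown here, so treat this as an illustrative approximation rather than the exact function body.

    from typing import Optional

    import torch


    def eager_attention_forward_sketch(
        query: torch.Tensor,   # (batch, heads, q_len, head_dim)
        key: torch.Tensor,     # (batch, heads, k_len, head_dim)
        value: torch.Tensor,   # (batch, heads, k_len, head_dim)
        attention_mask: Optional[torch.Tensor] = None,
        scaling: Optional[float] = None,
        dropout: float = 0.0,
        training: bool = False,
    ):
        if scaling is None:
            scaling = query.size(-1) ** -0.5
        # Scaling is applied directly to the raw scores; the relative_key /
        # relative_key_query adjustments removed by this patch no longer exist.
        attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
        if attention_mask is not None and attention_mask.ndim == 4:
            # Additive mask, sliced to the key length as in the patched code.
            attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=training)
        attn_output = torch.matmul(attn_weights, value).transpose(1, 2)
        return attn_output, attn_weights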
@@ -105,7 +99,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - position_embedding_type="absolute", use_cache=True, **kwargs, ): @@ -123,7 +116,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.pad_token_id = pad_token_id diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index f55c84b47176..9c0d09333325 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -519,7 +519,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -558,11 +557,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings diff --git a/src/transformers/models/altclip/configuration_altclip.py b/src/transformers/models/altclip/configuration_altclip.py index 474fc48081b5..f03f1c3167b1 100755 --- a/src/transformers/models/altclip/configuration_altclip.py +++ b/src/transformers/models/altclip/configuration_altclip.py @@ -67,12 +67,6 @@ class AltCLIPTextConfig(PretrainedConfig): bos_token_id (`int`, *optional*, defaults to 0): The id of the *beginning-of-sequence* token. eos_token_id (`Union[int, list[int]]`, *optional*, defaults to 2): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. 
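A minimal sketch of the embedding forward that the embedding hunks above (dummy_bert, roberta, albert, align) converge on: absolute position embeddings are now added unconditionally, with no `position_embedding_type` attribute or branch. Class name and default hyperparameters below are illustrative, BERT-style values, not taken from the patch.

    import torch
    from torch import nn


    class AbsoluteEmbeddingsSketch(nn.Module):
        def __init__(self, vocab_size=30522, hidden_size=768, max_position_embeddings=512,
                     type_vocab_size=2, layer_norm_eps=1e-12, hidden_dropout_prob=0.1):
            super().__init__()
            self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
            self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
            self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
            self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
            self.dropout = nn.Dropout(hidden_dropout_prob)
            # position_ids (1, len position emb) is contiguous in memory, as in the patched files
            self.register_buffer(
                "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
            )

        def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
            seq_length = input_ids.shape[1]
            position_ids = self.position_ids[:, :seq_length]
            inputs_embeds = self.word_embeddings(input_ids)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = inputs_embeds + token_type_embeddings
            # Position embeddings are added unconditionally after this patch.
            embeddings = embeddings + self.position_embeddings(position_ids)
            embeddings = self.LayerNorm(embeddings)
            return self.dropout(embeddings)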
@@ -114,7 +108,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - position_embedding_type="absolute", use_cache=True, project_dim=768, **kwargs, @@ -134,7 +127,6 @@ def __init__( self.initializer_range = initializer_range self.initializer_factor = initializer_factor self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.project_dim = project_dim diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index f40caa7af4ce..58bf63573e70 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -100,7 +100,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -152,11 +151,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -197,7 +196,7 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l class AltRobertaSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -214,12 +213,6 @@ def __init__(self, config, position_embedding_type=None): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) def forward( self, @@ -237,23 +230,6 @@ def forward( # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - query_length, key_length = query_layer.shape[2], key_layer.shape[2] - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in AltRobertaModel forward() function) @@ -298,11 +274,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class AltRobertaAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() - self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type - ) + self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](config) self.output = AltRobertaSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/bert/configuration_bert.py b/src/transformers/models/bert/configuration_bert.py index e7e51d3295ef..28367fbb8ff8 100644 --- a/src/transformers/models/bert/configuration_bert.py +++ b/src/transformers/models/bert/configuration_bert.py @@ -65,12 +65,6 @@ class BertConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). is_decoder (`bool`, *optional*, defaults to `False`): Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. 
use_cache (`bool`, *optional*, defaults to `True`): @@ -111,7 +105,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - position_embedding_type="absolute", use_cache=True, classifier_dropout=None, **kwargs, @@ -130,7 +123,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 1689da04aa52..f9f211e3025a 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -67,7 +67,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -108,11 +107,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -126,38 +125,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -173,7 +147,7 @@ def eager_attention_forward( class BertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -192,12 +166,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -235,11 +203,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -250,8 +213,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -259,7 +220,7 @@ def forward( class BertCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -278,12 +239,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -325,11 +280,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -340,8 +290,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -363,15 +311,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BertAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = BertCrossAttention if is_cross_attention else BertSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = BertSelfOutput(config) self.pruned_heads = set() @@ -458,7 +402,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = BertAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, @@ -720,8 +663,6 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = BertPooler(config) if add_pooling_layer else None - self.position_embedding_type = config.position_embedding_type - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/bert_generation/configuration_bert_generation.py b/src/transformers/models/bert_generation/configuration_bert_generation.py index e6cf054cc5e2..b604378418a6 100644 --- a/src/transformers/models/bert_generation/configuration_bert_generation.py +++ b/src/transformers/models/bert_generation/configuration_bert_generation.py @@ -60,12 +60,6 @@ class BertGenerationConfig(PretrainedConfig): Beginning of stream token id. eos_token_id (`int`, *optional*, defaults to 1): End of stream token id. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. 
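For reference, a sketch of the relative_key / relative_key_query score adjustment that these hunks delete, reconstructed from the removed lines; `distance_embedding` is the `nn.Embedding(2 * max_position_embeddings - 1, head_dim)` that is also removed. The decode-time single-query branch (the removed `use_cache` path) is omitted for brevity, so this covers only the full-sequence case.

    import torch


    def removed_relative_position_scores(
        query: torch.Tensor,                  # (batch, heads, q_len, head_dim)
        key: torch.Tensor,                    # (batch, heads, k_len, head_dim)
        distance_embedding: torch.nn.Embedding,
        max_position_embeddings: int,
        position_embedding_type: str = "relative_key",
    ) -> torch.Tensor:
        query_length, key_length = query.shape[2], key.shape[2]
        position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
        position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
        distance = position_ids_l - position_ids_r

        positional_embedding = distance_embedding(distance + max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query.dtype)  # fp16 compatibility

        # "relative_key": bias scores with query·relative-position embeddings.
        scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
        if position_embedding_type == "relative_key_query":
            # "relative_key_query" additionally biases with key·relative-position embeddings.
            scores = scores + torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
        return scores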
@@ -103,7 +97,6 @@ def __init__( pad_token_id=0, bos_token_id=2, eos_token_id=1, - position_embedding_type="absolute", use_cache=True, **kwargs, ): @@ -120,7 +113,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 12aee8a014b3..70aa1f0806bc 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -70,38 +70,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. - attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -118,7 +93,7 @@ def eager_attention_forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->BertGeneration class BertGenerationSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -137,12 +112,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, 
"position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -180,11 +149,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' - ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -195,8 +159,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -205,7 +167,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertCrossAttention with Bert->BertGeneration class BertGenerationCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -224,12 +186,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -271,11 +227,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -286,8 +237,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -296,15 +245,11 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BertGeneration,BERT->BERT_GENERATION class BertGenerationAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = BertGenerationCrossAttention if is_cross_attention else BertGenerationSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = BertGenerationSelfOutput(config) self.pruned_heads = set() @@ -394,7 +339,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = BertGenerationAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index f774c61c5964..86edfb5ed884 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -78,7 +78,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 18d4bdaab721..04b0f170ec7f 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -56,7 +56,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.config = config @@ -82,9 +81,9 @@ def forward( embeddings = inputs_embeds - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -115,10 +114,6 @@ def __init__(self, config, is_cross_attention, layer_idx=None): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = 
getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) def save_attn_gradients(self, attn_gradients): self.attn_gradients = attn_gradients @@ -198,22 +193,6 @@ def forward( # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in BlipTextModel forward() function) diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index 23145ffc543f..f95aa409e60c 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -146,12 +146,6 @@ class Blip2QFormerConfig(PretrainedConfig): The epsilon used by the layer normalization layers. pad_token_id (`int`, *optional*, defaults to 0): Index to be used for padding token. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). cross_attention_frequency (`int`, *optional*, defaults to 2): The frequency of adding cross-attention to the Transformer layers. 
encoder_hidden_size (`int`, *optional*, defaults to 1408): @@ -190,7 +184,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - position_embedding_type="absolute", cross_attention_frequency=2, encoder_hidden_size=1408, use_qformer_text_input=False, @@ -209,7 +202,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.cross_attention_frequency = cross_attention_frequency self.encoder_hidden_size = encoder_hidden_size self.use_qformer_text_input = use_qformer_text_input diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index cb4e36b37308..d45bde036528 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -546,10 +546,6 @@ def __init__(self, config, is_cross_attention=False): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.save_attention = False def save_attn_gradients(self, attn_gradients): @@ -597,22 +593,6 @@ def forward( # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: @@ -866,7 +846,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") def forward( self, @@ -885,9 +864,9 @@ def forward( if input_ids is not None: input_ids = input_ids.to(self.word_embeddings.weight.device) embeddings = self.word_embeddings(input_ids) - if self.position_embedding_type == "absolute": - position_embeddings = 
self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings if query_embeds is not None: # `query_embeds` are kept in fp32 when we use it with Qformer diff --git a/src/transformers/models/bridgetower/configuration_bridgetower.py b/src/transformers/models/bridgetower/configuration_bridgetower.py index 4c84b0a294da..fc363be9d57a 100644 --- a/src/transformers/models/bridgetower/configuration_bridgetower.py +++ b/src/transformers/models/bridgetower/configuration_bridgetower.py @@ -133,12 +133,6 @@ class BridgeTowerTextConfig(PretrainedConfig): testing). layer_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). is_decoder (`bool`, *optional*, defaults to `False`): Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. use_cache (`bool`, *optional*, defaults to `True`): @@ -177,7 +171,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - position_embedding_type="absolute", use_cache=True, **kwargs, ): @@ -195,7 +188,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 896ee175c7b1..30196d35f1ab 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -415,38 +415,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
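# A minimal standalone sketch of the relative_key / relative_key_query scoring that the
# hunks below delete, kept here for reference. Tensor names and the einsum strings mirror
# the removed code; the helper itself (and the cached-decoding branch it omits) is only
# illustrative, not part of the library.
import torch


def relative_position_scores(query, key, distance_embedding, max_position_embeddings,
                             position_embedding_type="relative_key"):
    # query / key: (batch, heads, seq_len, head_dim); distance_embedding: nn.Embedding of
    # size (2 * max_position_embeddings - 1, head_dim), as in the deleted __init__ lines.
    query_length, key_length = query.shape[2], key.shape[2]
    position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
    position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
    distance = position_ids_l - position_ids_r
    positional_embedding = distance_embedding(distance + max_position_embeddings - 1)
    positional_embedding = positional_embedding.to(dtype=query.dtype)  # fp16 compatibility

    scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
    if position_embedding_type == "relative_key_query":
        scores = scores + torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
    return scores  # the old eager path added this to query @ key^T before applying the scaling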
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -463,7 +438,7 @@ def eager_attention_forward( # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower class BridgeTowerSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -482,12 +457,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -525,11 +494,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -540,8 +504,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -550,7 +512,7 @@ def forward( # Copied from transformers.models.roberta.modeling_roberta.RobertaCrossAttention with Roberta->BridgeTower class BridgeTowerCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -569,12 +531,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -616,11 +572,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -631,8 +582,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -641,15 +590,11 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower,BERT->BRIDGE_TOWER class BridgeTowerAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = BridgeTowerCrossAttention if is_cross_attention else BridgeTowerSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = BridgeTowerSelfOutput(config) self.pruned_heads = set() @@ -704,7 +649,6 @@ def __init__(self, config, layer_idx=None): self.add_cross_attention = config.add_cross_attention self.crossattention = BridgeTowerAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, @@ -767,7 +711,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = BridgeTowerAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, @@ -895,7 +838,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -947,11 +889,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 517ff8b9b87a..ca323c8984cd 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -129,7 +129,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 
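# A minimal sketch of the embeddings path the hunks above and below converge on: absolute
# position embeddings are always added, so the `position_embedding_type` branch disappears.
# The class below is a self-contained illustration; names mirror the BERT-style modules but
# it is not one of the modified classes.
import torch
from torch import nn


class AbsolutePositionEmbeddings(nn.Module):
    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size,
                 layer_norm_eps=1e-12, hidden_dropout_prob=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        seq_length = input_ids.shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        embeddings = self.word_embeddings(input_ids) + self.token_type_embeddings(token_type_ids)
        embeddings = embeddings + self.position_embeddings(position_ids)  # unconditionally absolute
        return self.dropout(self.LayerNorm(embeddings))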
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) self.register_buffer( "token_type_ids", @@ -169,11 +168,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -197,10 +196,6 @@ def __init__(self, config): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder @@ -232,23 +227,6 @@ def forward( # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - # bbox positional encoding batch_size, n_head, seq_length, d_head = query_layer.shape bbox_pos_emb = bbox_pos_emb.view(seq_length, seq_length, batch_size, d_head) diff --git a/src/transformers/models/camembert/configuration_camembert.py b/src/transformers/models/camembert/configuration_camembert.py index 3979e5487443..9f3b71da0904 100644 --- a/src/transformers/models/camembert/configuration_camembert.py +++ b/src/transformers/models/camembert/configuration_camembert.py @@ -65,12 +65,6 @@ class CamembertConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. 
Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). is_decoder (`bool`, *optional*, defaults to `False`): Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. use_cache (`bool`, *optional*, defaults to `True`): @@ -113,7 +107,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - position_embedding_type="absolute", use_cache=True, classifier_dropout=None, **kwargs, @@ -132,7 +125,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index e5e361c9b7bb..b9d03d6657a4 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -65,38 +65,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
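# A minimal sketch of the eager attention path these hunks converge on, with the scaling folded
# into the dot product. The softmax / dropout / value tail is not shown in the hunks here and is
# assumed to follow the usual pattern unchanged.
import torch
from torch import nn


def simplified_eager_attention(query, key, value, attention_mask=None, scaling=None,
                               dropout=0.0, training=False):
    # query / key / value: (batch, heads, seq_len, head_dim); attention_mask: additive 4D mask.
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=training)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output.transpose(1, 2).contiguous(), attn_weights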
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -112,7 +87,7 @@ def eager_attention_forward( class CamembertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -131,12 +106,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -174,11 +143,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -189,8 +153,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -198,7 +160,7 @@ def forward( class CamembertCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -217,12 +179,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -264,11 +220,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -279,8 +230,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -302,15 +251,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class CamembertAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = CamembertCrossAttention if is_cross_attention else CamembertSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = CamembertSelfOutput(config) self.pruned_heads = set() @@ -397,7 +342,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = CamembertAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, @@ -526,7 +470,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -578,11 +521,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -700,8 +643,6 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = CamembertPooler(config) if add_pooling_layer else None - self.position_embedding_type = config.position_embedding_type - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index e4ed912dd6b8..7479ae7c5681 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -106,7 +106,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") def _hash_bucket_tensors(self, input_ids, num_hashes: int, num_buckets: int): """ @@ -171,12 +170,11 @@ def forward( ) token_type_embeddings = 
self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.char_position_embeddings(position_ids) - embeddings += position_embeddings + position_embeddings = self.char_position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -300,10 +298,6 @@ def __init__(self, config): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) def forward( self, @@ -337,22 +331,6 @@ def forward( # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = from_tensor.size()[1] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=from_tensor.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=from_tensor.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: if attention_mask.ndim == 3: diff --git a/src/transformers/models/chinese_clip/configuration_chinese_clip.py b/src/transformers/models/chinese_clip/configuration_chinese_clip.py index 776df308a898..c763d4ea2bc2 100644 --- a/src/transformers/models/chinese_clip/configuration_chinese_clip.py +++ b/src/transformers/models/chinese_clip/configuration_chinese_clip.py @@ -75,12 +75,6 @@ class ChineseCLIPTextConfig(PretrainedConfig): The epsilon used by the layer normalization layers. pad_token_id (`int`, *optional*, defaults to 0): Padding token id. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). 
- For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. @@ -119,7 +113,6 @@ def __init__( initializer_factor=1.0, layer_norm_eps=1e-12, pad_token_id=0, - position_embedding_type="absolute", use_cache=True, **kwargs, ): @@ -138,7 +131,6 @@ def __init__( self.initializer_range = initializer_range self.initializer_factor = initializer_factor self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 9872b397b318..72b9171d60ab 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -101,7 +101,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -140,11 +139,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings diff --git a/src/transformers/models/clap/configuration_clap.py b/src/transformers/models/clap/configuration_clap.py index 900e8d373f5a..7e4ee719bc9a 100644 --- a/src/transformers/models/clap/configuration_clap.py +++ b/src/transformers/models/clap/configuration_clap.py @@ -58,12 +58,6 @@ class ClapTextConfig(PretrainedConfig): The vocabulary size of the `token_type_ids` passed when calling [`ClapTextModel`]. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). is_decoder (`bool`, *optional*, defaults to `False`): Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. 
use_cache (`bool`, *optional*, defaults to `True`): @@ -111,7 +105,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - position_embedding_type="absolute", use_cache=True, projection_hidden_act="relu", **kwargs, @@ -130,7 +123,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_factor = initializer_factor self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.projection_hidden_act = projection_hidden_act self.projection_dim = projection_dim diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 885286ea3f49..67f4146a9980 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -971,7 +971,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=True ) @@ -1023,11 +1022,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings diff --git a/src/transformers/models/data2vec/configuration_data2vec_text.py b/src/transformers/models/data2vec/configuration_data2vec_text.py index f9518d67bf66..8ab16458c5d8 100644 --- a/src/transformers/models/data2vec/configuration_data2vec_text.py +++ b/src/transformers/models/data2vec/configuration_data2vec_text.py @@ -64,12 +64,6 @@ class Data2VecTextConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). is_decoder (`bool`, *optional*, defaults to `False`): Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. 
use_cache (`bool`, *optional*, defaults to `True`): @@ -112,7 +106,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - position_embedding_type="absolute", use_cache=True, classifier_dropout=None, **kwargs, @@ -131,7 +124,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 6ea41b626fff..66b83fd1bf5b 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -67,7 +67,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -119,11 +118,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -171,38 +170,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
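# A short check of what the deleted "Scaling is shifted in case of embeddings being relative"
# comment was about: the old path added the relative scores before multiplying by the scale, so
# both terms ended up scaled; with the relative terms gone, the scale folds directly into the
# matmul as in the replacement line below. Shapes here are arbitrary toy values.
import torch

b, h, l, d = 2, 4, 8, 16
q, k = torch.randn(b, h, l, d), torch.randn(b, h, l, d)
rel = torch.randn(b, h, l, l)  # stand-in for the removed relative_position_scores term
s = d ** -0.5

old_path = (torch.matmul(q, k.transpose(2, 3)) + rel) * s  # relative terms were scaled too
new_path = torch.matmul(q, k.transpose(2, 3)) * s          # absolute-only replacement
assert torch.allclose(old_path - rel * s, new_path, atol=1e-5)  # differ only by the scaled relative term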
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -218,7 +192,7 @@ def eager_attention_forward( class Data2VecTextSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -237,12 +211,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -280,11 +248,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -295,8 +258,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -304,7 +265,7 @@ def forward( class Data2VecTextCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -323,12 +284,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -370,11 +325,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -385,8 +335,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -408,15 +356,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class Data2VecTextAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = Data2VecTextCrossAttention if is_cross_attention else Data2VecTextSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = Data2VecTextSelfOutput(config) self.pruned_heads = set() @@ -503,7 +447,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = Data2VecTextAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, @@ -660,8 +603,6 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = Data2VecTextPooler(config) if add_pooling_layer else None - self.position_embedding_type = config.position_embedding_type - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/dpr/configuration_dpr.py b/src/transformers/models/dpr/configuration_dpr.py index 03b169002493..4b310b673f63 100644 --- a/src/transformers/models/dpr/configuration_dpr.py +++ b/src/transformers/models/dpr/configuration_dpr.py @@ -64,12 +64,6 @@ class DPRConfig(PretrainedConfig): The epsilon used by the layer normalization layers. pad_token_id (`int`, *optional*, defaults to 0): Padding token id. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). projection_dim (`int`, *optional*, defaults to 0): Dimension of the projection for the context and question encoders. If it is set to zero (default), then no projection is done. 
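# A small sketch of what the config-side deletions amount to, assuming the base PretrainedConfig
# does not set the attribute anywhere else: a freshly constructed config no longer carries
# `position_embedding_type`, while a value left in an older serialized config would only reach the
# object through the generic **kwargs handling.
from transformers import DPRConfig

config = DPRConfig()
print(hasattr(config, "position_embedding_type"))  # expected to print False after this change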
@@ -106,7 +100,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - position_embedding_type="absolute", projection_dim: int = 0, **kwargs, ): @@ -125,7 +118,6 @@ def __init__( self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.projection_dim = projection_dim - self.position_embedding_type = position_embedding_type __all__ = ["DPRConfig"] diff --git a/src/transformers/models/electra/configuration_electra.py b/src/transformers/models/electra/configuration_electra.py index f12756d976b3..481925519c07 100644 --- a/src/transformers/models/electra/configuration_electra.py +++ b/src/transformers/models/electra/configuration_electra.py @@ -89,12 +89,6 @@ class ElectraConfig(PretrainedConfig): Argument used when doing sequence summary. Used in the sequence classification and multiple choice models. The dropout ratio to be used after the projection and activation. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. @@ -138,7 +132,6 @@ def __init__( summary_activation="gelu", summary_last_dropout=0.1, pad_token_id=0, - position_embedding_type="absolute", use_cache=True, classifier_dropout=None, **kwargs, @@ -163,7 +156,6 @@ def __init__( self.summary_use_proj = summary_use_proj self.summary_activation = summary_activation self.summary_last_dropout = summary_last_dropout - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 921e545afc35..7f166846fb86 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -74,7 +74,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False ) @@ -113,11 +112,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -132,38 +131,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: 
Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. - attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -180,7 +154,7 @@ def eager_attention_forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Electra class ElectraSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -199,12 +173,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -242,11 +210,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -257,8 +220,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -267,7 +228,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertCrossAttention with Bert->Electra class ElectraCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -286,12 +247,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -333,11 +288,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -348,8 +298,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -373,15 +321,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra,BERT->ELECTRA class ElectraAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = ElectraCrossAttention if is_cross_attention else ElectraSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = ElectraSelfOutput(config) self.pruned_heads = set() @@ -471,7 +415,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = ElectraAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, diff --git a/src/transformers/models/ernie/configuration_ernie.py b/src/transformers/models/ernie/configuration_ernie.py index abf300f0ce51..e7f5bd1aff1a 100644 --- a/src/transformers/models/ernie/configuration_ernie.py +++ b/src/transformers/models/ernie/configuration_ernie.py @@ -71,12 +71,6 @@ class ErnieConfig(PretrainedConfig): The epsilon used by the layer normalization layers. pad_token_id (`int`, *optional*, defaults to 0): Padding token id. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. 
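Note: with the `relative_key` / `relative_key_query` branches gone, every `eager_attention_forward` kept by this patch reduces to plain scaled dot-product attention. The standalone sketch below is not part of the patch; the function name and tensor shapes are invented for illustration, and it only mirrors the computation that remains in the kept code path:

import torch
from torch import nn


def simplified_eager_attention(query, key, value, attention_mask=None, scaling=None, dropout=0.0, training=False):
    # query/key/value: [batch, num_heads, seq_len, head_dim]
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    # Raw scores are a plain scaled dot product; no relative-position terms are added anymore.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None and attention_mask.ndim == 4:
        # Additive mask, cropped to the key length, mirroring the kept code path.
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=training)
    return torch.matmul(attn_weights, value), attn_weights


# Illustrative shapes only: batch=2, heads=12, seq_len=16, head_dim=64.
q = k = v = torch.randn(2, 12, 16, 64)
attn_output, attn_weights = simplified_eager_attention(q, k, v)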
@@ -117,7 +111,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - position_embedding_type="absolute", use_cache=True, classifier_dropout=None, **kwargs, @@ -138,7 +131,6 @@ def __init__( self.use_task_id = use_task_id self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 01a5d9dddd2a..0978ab4554a0 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -71,7 +71,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -117,11 +116,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings # add `task_type_id` for ERNIE model if self.use_task_id: @@ -143,38 +141,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -190,7 +163,7 @@ def eager_attention_forward( class ErnieSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -209,12 +182,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -252,11 +219,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -267,8 +229,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -276,7 +236,7 @@ def forward( class ErnieCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -295,12 +255,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -342,11 +296,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -357,8 +306,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -380,15 +327,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class ErnieAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = ErnieCrossAttention if is_cross_attention else ErnieSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = ErnieSelfOutput(config) self.pruned_heads = set() @@ -475,7 +418,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = ErnieAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, @@ -682,8 +624,6 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = ErniePooler(config) if add_pooling_layer else None - self.position_embedding_type = config.position_embedding_type - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/ernie/modular_ernie.py b/src/transformers/models/ernie/modular_ernie.py index eba860e93185..b93adf8b7507 100644 --- a/src/transformers/models/ernie/modular_ernie.py +++ b/src/transformers/models/ernie/modular_ernie.py @@ -111,11 +111,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings # add `task_type_id` for ERNIE model if self.use_task_id: diff --git a/src/transformers/models/esm/configuration_esm.py b/src/transformers/models/esm/configuration_esm.py index fabfb4ebd6d3..afd5ee255ad0 100644 --- a/src/transformers/models/esm/configuration_esm.py +++ b/src/transformers/models/esm/configuration_esm.py @@ -67,11 +67,7 @@ class EsmConfig(PretrainedConfig): layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query", "rotary"`. - For positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). 
- For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). + Type of position embedding. Choose either `"absolute"` or `"rotary"`. is_decoder (`bool`, *optional*, defaults to `False`): Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. use_cache (`bool`, *optional*, defaults to `True`): diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 76a0e178b0e2..6be0469b5cbe 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, @@ -33,11 +34,15 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_esm import EsmConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -251,44 +256,28 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): return position_ids.unsqueeze(0).expand(input_shape) +# Copied from transformers.models.bert.modeling_bert.eager_attention_forward def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): - # ESM applies relative position embeddings and we don't copy from Llama + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. 
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling - if hasattr(module, "position_embedding_type") and module.position_embedding_type in [ - "relative_key", - "relative_key_query", - ]: - seq_length = query.shape[2] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - relative_position_scores = relative_position_scores_query + relative_position_scores_key - - attn_weights = attn_weights + relative_position_scores - - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + if attention_mask is not None and attention_mask.ndim == 4: + attention_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -317,14 +306,12 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cros self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = config.attention_probs_dropout_prob + + self.rotary_embeddings = None self.position_embedding_type = position_embedding_type or getattr( config, "position_embedding_type", "absolute" ) - self.rotary_embeddings = None - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - elif self.position_embedding_type == "rotary": + if self.position_embedding_type == "rotary": self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size) self.scaling = 1.0 # For BC we apply scaling before RoPE @@ -362,11 +349,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type in ["relative_key", "relative_key_query"]: - raise ValueError( - f"ESM {self.config._attn_implementation} attention does not support {self.position_embedding_type} embeddings. 
" - "Set attention explicitly to 'eager' with `model.set_attn_implementation('eager')`" - ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -701,31 +683,23 @@ def forward( position_ids=position_ids, ) - if self.config._attn_implementation != "flash_attention_2": - batch_size, seq_length = inputs_embeds.shape[:-1] - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length)), device=inputs_embeds.device) - - attention_mask: torch.Tensor = self.get_extended_attention_mask( - attention_mask, input_shape=(batch_size, seq_length) - ) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + inputs_embeds.shape[:2], + inputs_embeds, + ) encoder_outputs = self.encoder( inputs_embeds, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, + encoder_attention_mask=encoder_attention_mask, **kwargs, ) sequence_output = encoder_outputs[0] @@ -736,6 +710,61 @@ def forward( pooler_output=pooled_output, ) + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if "flash" in self.config._attn_implementation: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if "flash" in self.config._attn_implementation: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + 
inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + def predict_contacts(self, tokens, attention_mask): attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions attns = torch.stack(attns, dim=1) # Matches the original model layout diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index 75db8a22a022..d1d3d8e4e90a 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -20,18 +20,18 @@ # limitations under the License. import math -import warnings from dataclasses import dataclass from typing import Callable, Optional, Union import torch -from torch import Tensor, nn +from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, @@ -41,15 +41,19 @@ ModelOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, ModuleUtilsMixin, PreTrainedModel, get_parameter_dtype +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_evolla import EvollaConfig, SaProtConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + def create_position_ids_from_input_ids(input_ids, padding_idx): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols @@ -225,38 +229,21 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): - # EVOLLA_SA_PROT applies relative position embeddings and we don't copy from Llama + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. 
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling - if hasattr(module, "position_embedding_type") and module.position_embedding_type in [ - "relative_key", - "relative_key_query", - ]: - seq_length = query.shape[2] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=attn_weights.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - relative_position_scores = relative_position_scores_query + relative_position_scores_key - - attn_weights = attn_weights + relative_position_scores - - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + if attention_mask is not None and attention_mask.ndim == 4: + attention_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -285,14 +272,12 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cros self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = config.attention_probs_dropout_prob + + self.rotary_embeddings = None self.position_embedding_type = position_embedding_type or getattr( config, "position_embedding_type", "absolute" ) - self.rotary_embeddings = None - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - elif self.position_embedding_type == "rotary": + if self.position_embedding_type == "rotary": self.rotary_embeddings = EvollaSaProtRotaryEmbedding(dim=self.attention_head_size) self.is_decoder = config.is_decoder @@ -330,11 +315,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type in ["relative_key", "relative_key_query"]: - raise ValueError( - f"ESM {self.config._attn_implementation} attention does not support {self.position_embedding_type} embeddings. 
" - "Set attention explicitly to 'eager' with `model.set_attn_implementation('eager')`" - ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -550,6 +530,7 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel): _no_split_modules = ["EvollaSaProtLayer"] _supports_flash_attn = True _supports_sdpa = True + _supports_flex_attn = True _supports_attention_backend = True _can_record_outputs = { @@ -601,6 +582,7 @@ def forward( self, input_ids: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: input_shape = input_ids.size() batch_size, seq_length = input_shape @@ -608,10 +590,14 @@ def forward( device = input_ids.device if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length)), device=device) - inputs_embeds = self.embeddings(input_ids=input_ids, attention_mask=attention_mask) - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) - encoder_outputs = self.encoder(inputs_embeds, attention_mask=extended_attention_mask) + + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) + + encoder_outputs = self.encoder(inputs_embeds, attention_mask=attention_mask, **kwargs) sequence_output = encoder_outputs[0] return BaseModelOutputWithPoolingAndCrossAttentions( @@ -621,61 +607,26 @@ def forward( cross_attentions=encoder_outputs.cross_attentions, ) - def get_extended_attention_mask( + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( self, - attention_mask: Tensor, - input_shape: tuple[int], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> Tensor: - """ - Makes broadcastable attention and causal masks so that future and masked tokens are ignored. - - Arguments: - attention_mask (`torch.Tensor`): - Mask with ones indicating tokens to attend to, zeros for tokens to ignore. - input_shape (`Tuple[int]`): - The shape of the input to the model. - - Returns: - `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. - """ - if dtype is None: - dtype = get_parameter_dtype(self) - - if not (attention_mask.dim() == 2 and self.config.is_decoder): - # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` - if device is not None: - warnings.warn( - "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning - ) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- if attention_mask.dim() == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder: - extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( - input_shape, attention_mask, device - ) + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if "flash" in self.config._attn_implementation: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: - extended_attention_mask = attention_mask[:, None, None, :] - else: - raise ValueError( - f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" - ) + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min - return extended_attention_mask + return attention_mask class EvollaSequenceCompressorAttention(nn.Module): @@ -1332,9 +1283,9 @@ class EvollaPreTrainedModel(PreTrainedModel): "EvollaSequenceAlignerCrossAttention", ] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn = False # see dependency on `EvollaSaProtProteinEncoder` + _supports_flash_attn = False # see dependency on `EvollaSequenceCompressorResampler` _supports_sdpa = True - _supports_flex_attn = False # see dependency on `EvollaSaProtProteinEncoder` + _supports_flex_attn = False # see dependency on `EvollaSequenceCompressorResampler` _can_compile_fullgraph = True _supports_attention_backend = False diff --git a/src/transformers/models/evolla/modular_evolla.py b/src/transformers/models/evolla/modular_evolla.py index e2db43a7d787..b69b27dbf26a 100644 --- a/src/transformers/models/evolla/modular_evolla.py +++ b/src/transformers/models/evolla/modular_evolla.py @@ -13,26 +13,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import warnings from dataclasses import dataclass from typing import Optional, Union import torch -from torch import Tensor, nn +from torch import nn from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPast, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithPast, ModelOutput, ) -from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel, get_parameter_dtype +from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, can_return_tuple, + is_torch_flex_attn_available, logging, ) from ...utils.deprecation import deprecate_kwarg @@ -59,6 +60,10 @@ from .configuration_evolla import EvollaConfig, SaProtConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -145,14 +150,12 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cros self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = config.attention_probs_dropout_prob + + self.rotary_embeddings = None self.position_embedding_type = position_embedding_type or getattr( config, "position_embedding_type", "absolute" ) - self.rotary_embeddings = None - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) - elif self.position_embedding_type == "rotary": + if self.position_embedding_type == "rotary": self.rotary_embeddings = EvollaSaProtRotaryEmbedding(dim=self.attention_head_size) self.is_decoder = config.is_decoder @@ -195,6 +198,7 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel): _no_split_modules = ["EvollaSaProtLayer"] _supports_flash_attn = True _supports_sdpa = True + _supports_flex_attn = True _supports_attention_backend = True _can_record_outputs = { @@ -246,6 +250,7 @@ def forward( self, input_ids: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: input_shape = input_ids.size() batch_size, seq_length = input_shape @@ -253,10 +258,14 @@ def forward( device = input_ids.device if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length)), device=device) - inputs_embeds = self.embeddings(input_ids=input_ids, attention_mask=attention_mask) - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) - encoder_outputs = self.encoder(inputs_embeds, attention_mask=extended_attention_mask) + + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) + + encoder_outputs = self.encoder(inputs_embeds, attention_mask=attention_mask, **kwargs) sequence_output = encoder_outputs[0] return BaseModelOutputWithPoolingAndCrossAttentions( @@ -266,61 +275,26 @@ def forward( cross_attentions=encoder_outputs.cross_attentions, ) - def get_extended_attention_mask( + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( self, - attention_mask: Tensor, - input_shape: tuple[int], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> 
Tensor: - """ - Makes broadcastable attention and causal masks so that future and masked tokens are ignored. - - Arguments: - attention_mask (`torch.Tensor`): - Mask with ones indicating tokens to attend to, zeros for tokens to ignore. - input_shape (`Tuple[int]`): - The shape of the input to the model. - - Returns: - `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. - """ - if dtype is None: - dtype = get_parameter_dtype(self) - - if not (attention_mask.dim() == 2 and self.config.is_decoder): - # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` - if device is not None: - warnings.warn( - "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning - ) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - if attention_mask.dim() == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder: - extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( - input_shape, attention_mask, device - ) + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if "flash" in self.config._attn_implementation: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: - extended_attention_mask = attention_mask[:, None, None, :] - else: - raise ValueError( - f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" - ) + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. 
- extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min - return extended_attention_mask + return attention_mask class EvollaSequenceCompressorAttention(nn.Module): @@ -786,8 +760,8 @@ def forward( class EvollaPreTrainedModel(LlamaPreTrainedModel): - _supports_flash_attn = False # see dependency on `EvollaSaProtProteinEncoder` - _supports_flex_attn = False # see dependency on `EvollaSaProtProteinEncoder` + _supports_flash_attn = False # see dependency on `EvollaSequenceCompressorResampler` + _supports_flex_attn = False # see dependency on `EvollaSequenceCompressorResampler` _supports_attention_backend = False _no_split_modules = [ "EvollaDecoderLayer", diff --git a/src/transformers/models/flava/configuration_flava.py b/src/transformers/models/flava/configuration_flava.py index b7bcb920e47a..1b587ae39424 100644 --- a/src/transformers/models/flava/configuration_flava.py +++ b/src/transformers/models/flava/configuration_flava.py @@ -148,12 +148,6 @@ class FlavaTextConfig(PretrainedConfig): max_position_embeddings (`int`, *optional*, defaults to 512): The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). For VL, max_length passed to model is 77. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. 
num_hidden_layers (`int`, *optional*, defaults to 12): @@ -205,7 +199,6 @@ def __init__( vocab_size: int = 30522, type_vocab_size: int = 2, max_position_embeddings: int = 512, - position_embedding_type: str = "absolute", hidden_size: int = 768, num_hidden_layers: int = 12, num_attention_heads: int = 12, @@ -224,7 +217,6 @@ def __init__( self.vocab_size = vocab_size self.type_vocab_size = type_vocab_size self.max_position_embeddings = max_position_embeddings - self.position_embedding_type = position_embedding_type self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 5d63b5e132ad..2ef41bde4e47 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -378,7 +378,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -411,11 +410,11 @@ def forward( inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings diff --git a/src/transformers/models/git/configuration_git.py b/src/transformers/models/git/configuration_git.py index 86c85854ff98..2854e005a9f8 100644 --- a/src/transformers/models/git/configuration_git.py +++ b/src/transformers/models/git/configuration_git.py @@ -140,12 +140,6 @@ class GitConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). 
num_image_with_embedding (`int`, *optional*): @@ -184,7 +178,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - position_embedding_type="absolute", use_cache=True, tie_word_embeddings=False, bos_token_id=101, @@ -210,7 +203,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.tie_word_embeddings = tie_word_embeddings self.num_image_with_embedding = num_image_with_embedding diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index c1e823767135..5528bc1addff 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -79,7 +79,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -106,16 +105,16 @@ def forward( else: embeddings = inputs_embeds - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class GitSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, layer_idx=None): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -142,12 +141,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( @@ -187,28 +180,6 @@ def forward( # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_values is not None: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( - -1, 1 - ) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in GitModel forward() function) @@ -251,11 +222,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class GitAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, layer_idx=None): + def __init__(self, config, layer_idx=None): super().__init__() - self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type, layer_idx=layer_idx - ) + self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) self.output = GitSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/ibert/configuration_ibert.py b/src/transformers/models/ibert/configuration_ibert.py index 963e6e6c9ed0..6c5023ce45b3 100644 --- a/src/transformers/models/ibert/configuration_ibert.py +++ b/src/transformers/models/ibert/configuration_ibert.py @@ -65,12 +65,6 @@ class IBertConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). quant_mode (`bool`, *optional*, defaults to `False`): Whether to quantize the model or not. 
force_dequant (`str`, *optional*, defaults to `"none"`): @@ -100,7 +94,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - position_embedding_type="absolute", quant_mode=False, force_dequant="none", **kwargs, @@ -119,7 +112,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.quant_mode = quant_mode self.force_dequant = force_dequant diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index 761fd515acc6..84f557626145 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -73,7 +73,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") # End copy self.padding_idx = config.pad_token_id @@ -132,14 +131,13 @@ def forward( identity_scaling_factor=token_type_embeddings_scaling_factor, ) - if self.position_embedding_type == "absolute": - position_embeddings, position_embeddings_scaling_factor = self.position_embeddings(position_ids) - embeddings, embeddings_scaling_factor = self.embeddings_act1( - embeddings, - embeddings_scaling_factor, - identity=position_embeddings, - identity_scaling_factor=position_embeddings_scaling_factor, - ) + position_embeddings, position_embeddings_scaling_factor = self.position_embeddings(position_ids) + embeddings, embeddings_scaling_factor = self.embeddings_act1( + embeddings, + embeddings_scaling_factor, + identity=position_embeddings, + identity_scaling_factor=position_embeddings_scaling_factor, + ) embeddings, embeddings_scaling_factor = self.LayerNorm(embeddings, embeddings_scaling_factor) embeddings = self.dropout(embeddings) @@ -217,9 +215,6 @@ def __init__(self, config): self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type != "absolute": - raise ValueError("I-BERT only supports 'absolute' for `config.position_embedding_type`") self.softmax = IntSoftmax(self.act_bit, quant_mode=self.quant_mode, force_dequant=config.force_dequant) diff --git a/src/transformers/models/instructblip/configuration_instructblip.py b/src/transformers/models/instructblip/configuration_instructblip.py index 9b8323f15f05..56e6fc60e574 100644 --- a/src/transformers/models/instructblip/configuration_instructblip.py +++ b/src/transformers/models/instructblip/configuration_instructblip.py @@ -146,12 +146,6 @@ class InstructBlipQFormerConfig(PretrainedConfig): The epsilon used by the layer normalization layers. pad_token_id (`int`, *optional*, defaults to 0): Token id used for padding sequences. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). 
- For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). cross_attention_frequency (`int`, *optional*, defaults to 2): The frequency of adding cross-attention to the Transformer layers. encoder_hidden_size (`int`, *optional*, defaults to 1408): @@ -188,7 +182,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - position_embedding_type="absolute", cross_attention_frequency=2, encoder_hidden_size=1408, **kwargs, @@ -206,7 +199,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.cross_attention_frequency = cross_attention_frequency self.encoder_hidden_size = encoder_hidden_size diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 5cddbcdfd3bf..6b4799b54602 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -451,10 +451,6 @@ def __init__(self, config, is_cross_attention=False): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.save_attention = False def save_attn_gradients(self, attn_gradients): @@ -502,22 +498,6 @@ def forward( # Take the dot product between "query" and "key" to get the raw attention scores. 
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - attention_scores = attention_scores / math.sqrt(self.attention_head_size) attention_scores_dtype = attention_scores.dtype @@ -773,7 +753,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.config = config @@ -794,9 +773,9 @@ def forward( if input_ids is not None: embeddings = self.word_embeddings(input_ids) - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids.to(embeddings.device)) - embeddings = embeddings + position_embeddings + + position_embeddings = self.position_embeddings(position_ids.to(embeddings.device)) + embeddings = embeddings + position_embeddings if query_embeds is not None: embeddings = torch.cat((query_embeds, embeddings), dim=1) diff --git a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py index af2acc833876..340f04cb2327 100644 --- a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py @@ -152,12 +152,6 @@ class InstructBlipVideoQFormerConfig(PretrainedConfig): The epsilon used by the layer normalization layers. pad_token_id (`int`, *optional*, defaults to 0): Token id used for padding sequences. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). cross_attention_frequency (`int`, *optional*, defaults to 2): The frequency of adding cross-attention to the Transformer layers. 
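Note: every hunk in this patch that deletes a `relative_key` / `relative_key_query` branch is removing the same scoring term; the InstructBLIP Q-Former hunk above is one instance of it. For reference, that deleted computation is summarized below as a standalone sketch reconstructed from the removed lines; the helper name and signature are illustrative and not part of any patched file.

    import torch
    from torch import nn

    def removed_relative_position_scores(
        query_layer: torch.Tensor,         # (batch, heads, seq_len, head_dim)
        key_layer: torch.Tensor,           # (batch, heads, seq_len, head_dim)
        distance_embedding: nn.Embedding,  # was nn.Embedding(2 * max_position_embeddings - 1, head_dim)
        max_position_embeddings: int,
        mode: str = "relative_key",
    ) -> torch.Tensor:
        seq_length = query_layer.shape[2]
        position_ids_l = torch.arange(seq_length, dtype=torch.long, device=query_layer.device).view(-1, 1)
        position_ids_r = torch.arange(seq_length, dtype=torch.long, device=query_layer.device).view(1, -1)
        distance = position_ids_l - position_ids_r
        positional_embedding = distance_embedding(distance + max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
        scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
        if mode == "relative_key_query":
            scores = scores + torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
        # These scores were added to query @ key^T before the 1/sqrt(head_dim) scaling.
        return scores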
encoder_hidden_size (`int`, *optional*, defaults to 1408): @@ -194,7 +188,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - position_embedding_type="absolute", cross_attention_frequency=2, encoder_hidden_size=1408, **kwargs, @@ -212,7 +205,6 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.cross_attention_frequency = cross_attention_frequency self.encoder_hidden_size = encoder_hidden_size diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index abcaa17f70f7..44106573cbea 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -418,10 +418,6 @@ def __init__(self, config, is_cross_attention=False): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.save_attention = False def save_attn_gradients(self, attn_gradients): @@ -469,22 +465,6 @@ def forward( # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - attention_scores = attention_scores / math.sqrt(self.attention_head_size) attention_scores_dtype = attention_scores.dtype @@ -735,7 +715,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.config = config @@ -756,9 +735,9 @@ def forward( if input_ids is not None: embeddings = self.word_embeddings(input_ids) - if self.position_embedding_type == "absolute": - position_embeddings = 
self.position_embeddings(position_ids.to(embeddings.device)) - embeddings = embeddings + position_embeddings + + position_embeddings = self.position_embeddings(position_ids.to(embeddings.device)) + embeddings = embeddings + position_embeddings if query_embeds is not None: embeddings = torch.cat((query_embeds, embeddings), dim=1) diff --git a/src/transformers/models/lilt/configuration_lilt.py b/src/transformers/models/lilt/configuration_lilt.py index 940fad4aa810..76bdc6094703 100644 --- a/src/transformers/models/lilt/configuration_lilt.py +++ b/src/transformers/models/lilt/configuration_lilt.py @@ -58,12 +58,6 @@ class LiltConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). classifier_dropout (`float`, *optional*): The dropout ratio for the classification head. channel_shrink_ratio (`int`, *optional*, defaults to 4): @@ -102,7 +96,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - position_embedding_type="absolute", classifier_dropout=None, channel_shrink_ratio=4, max_2d_position_embeddings=1024, @@ -122,7 +115,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.classifier_dropout = classifier_dropout self.channel_shrink_ratio = channel_shrink_ratio self.max_2d_position_embeddings = max_2d_position_embeddings diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 191a58836b06..dadae8db246a 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -53,7 +53,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") # End copy self.padding_idx = config.pad_token_id @@ -88,11 +87,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings, position_ids @@ -183,7 +182,7 @@ def forward(self, bbox=None, position_ids=None): class LiltSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, layer_idx=None): + def __init__(self, config, 
layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -210,12 +209,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): ) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.channel_shrink_ratio = config.channel_shrink_ratio self.layer_idx = layer_idx @@ -245,22 +238,6 @@ def forward( attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) layout_attention_scores = torch.matmul(layout_query_layer, layout_key_layer.transpose(-1, -2)) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - tmp_attention_scores = attention_scores / math.sqrt(self.attention_head_size) tmp_layout_attention_scores = layout_attention_scores / math.sqrt( self.attention_head_size // self.channel_shrink_ratio @@ -327,9 +304,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class LiltAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, layer_idx=None): + def __init__(self, config, layer_idx=None): super().__init__() - self.self = LiltSelfAttention(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) + self.self = LiltSelfAttention(config, layer_idx=layer_idx) self.output = LiltSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/megatron_bert/configuration_megatron_bert.py b/src/transformers/models/megatron_bert/configuration_megatron_bert.py index 1505388e2925..f44404d9f76d 100644 --- a/src/transformers/models/megatron_bert/configuration_megatron_bert.py +++ b/src/transformers/models/megatron_bert/configuration_megatron_bert.py @@ -60,12 +60,6 @@ class MegatronBertConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. 
- position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). is_decoder (`bool`, *optional*, defaults to `False`): Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. use_cache (`bool`, *optional*, defaults to `True`): @@ -104,7 +98,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - position_embedding_type="absolute", use_cache=True, **kwargs, ): @@ -122,7 +115,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 07c011359023..e17a4b206fe7 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -66,7 +66,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") def forward( self, @@ -92,11 +91,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings # Megatron BERT moves that layer norm after the drop-out (and to each layer). 
# embeddings = self.LayerNorm(embeddings) @@ -106,7 +104,7 @@ def forward( # copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MegatronBert class MegatronBertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, layer_idx=None): + def __init__(self, config, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -123,12 +121,6 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.layer_idx = layer_idx @@ -190,28 +182,6 @@ def forward( # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_values is not None: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( - -1, 1 - ) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in MegatronBertModel forward() function) diff --git a/src/transformers/models/mra/configuration_mra.py b/src/transformers/models/mra/configuration_mra.py index 16b064c98f7e..c87e9a291893 100644 --- a/src/transformers/models/mra/configuration_mra.py +++ b/src/transformers/models/mra/configuration_mra.py @@ -60,8 +60,6 @@ class MraConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. 
Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. block_per_row (`int`, *optional*, defaults to 4): Used to set the budget for the high resolution scale. approx_mode (`str`, *optional*, defaults to `"full"`): @@ -103,7 +101,6 @@ def __init__( type_vocab_size=1, initializer_range=0.02, layer_norm_eps=1e-5, - position_embedding_type="absolute", block_per_row=4, approx_mode="full", initial_prior_first_n_blocks=0, @@ -127,7 +124,6 @@ def __init__( self.initializer_range = initializer_range self.type_vocab_size = type_vocab_size self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.block_per_row = block_per_row self.approx_mode = approx_mode self.initial_prior_first_n_blocks = initial_prior_first_n_blocks diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index 1616bcfdf979..7c8acd6a67af 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -474,7 +474,6 @@ def __init__(self, config): # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), @@ -506,18 +505,18 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class MraSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -541,9 +540,6 @@ def __init__(self, config, position_embedding_type=None): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = ( - position_embedding_type if position_embedding_type is not None else config.position_embedding_type - ) self.num_block = (config.max_position_embeddings // 32) * config.block_per_row self.num_block = min(self.num_block, int((config.max_position_embeddings // 32) ** 2)) @@ -631,9 +627,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class MraAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() - self.self = MraSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = MraSelfAttention(config) self.output = MraSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index ffd46ed0c278..e27efc082384 100755 --- 
a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -59,7 +59,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2, persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), @@ -91,18 +90,18 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class NystromformerSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -128,9 +127,6 @@ def __init__(self, config, position_embedding_type=None): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) if self.conv_kernel_size is not None: self.conv = nn.Conv2d( @@ -253,9 +249,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class NystromformerAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() - self.self = NystromformerSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = NystromformerSelfAttention(config) self.output = NystromformerSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/roberta/configuration_roberta.py b/src/transformers/models/roberta/configuration_roberta.py index 04917804a225..3af141bd0044 100644 --- a/src/transformers/models/roberta/configuration_roberta.py +++ b/src/transformers/models/roberta/configuration_roberta.py @@ -65,12 +65,6 @@ class RobertaConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). is_decoder (`bool`, *optional*, defaults to `False`): Whether the model is used as a decoder or not. 
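Note: the embedding modules touched so far (I-BERT, InstructBLIP Q-Former, LiLT, MegatronBERT, MRA, Nyströmformer, and the ones that follow) all drop the `position_embedding_type` check and add position embeddings unconditionally. For the non-quantized models, the resulting flow reduces to roughly the sketch below; the standalone function and flat argument list are illustrative, since each model keeps these objects as module attributes.

    import torch
    from torch import nn

    def embeddings_forward_sketch(
        word_embeddings: nn.Embedding,
        position_embeddings: nn.Embedding,
        token_type_embeddings: nn.Embedding,
        layer_norm: nn.LayerNorm,
        dropout: nn.Dropout,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        embeddings = word_embeddings(input_ids) + token_type_embeddings(token_type_ids)
        # No position_embedding_type switch left: absolute position embeddings are always added.
        embeddings = embeddings + position_embeddings(position_ids)
        return dropout(layer_norm(embeddings))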
If `False`, the model is used as an encoder. use_cache (`bool`, *optional*, defaults to `True`): @@ -113,7 +107,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - position_embedding_type="absolute", use_cache=True, classifier_dropout=None, **kwargs, @@ -132,7 +125,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 7462a68fa97f..7dea29c992fb 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -68,7 +68,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -120,11 +119,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -172,38 +171,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -219,7 +193,7 @@ def eager_attention_forward( class RobertaSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -238,12 +212,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -281,11 +249,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -296,8 +259,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -305,7 +266,7 @@ def forward( class RobertaCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -324,12 +285,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -371,11 +326,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -386,8 +336,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -409,15 +357,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RobertaAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = RobertaCrossAttention if is_cross_attention else RobertaSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = RobertaSelfOutput(config) self.pruned_heads = set() @@ -504,7 +448,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = RobertaAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, @@ -670,8 +613,6 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = RobertaPooler(config) if add_pooling_layer else None - self.position_embedding_type = config.position_embedding_type - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/roberta/modular_roberta.py b/src/transformers/models/roberta/modular_roberta.py index b7b65f004499..cf10f735cbb2 100644 --- a/src/transformers/models/roberta/modular_roberta.py +++ b/src/transformers/models/roberta/modular_roberta.py @@ -93,11 +93,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings diff --git a/src/transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py index 72bc808c450d..1e95076743dc 100644 --- a/src/transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py @@ -66,12 +66,6 @@ class RobertaPreLayerNormConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. 
For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). is_decoder (`bool`, *optional*, defaults to `False`): Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. use_cache (`bool`, *optional*, defaults to `True`): @@ -114,7 +108,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - position_embedding_type="absolute", use_cache=True, classifier_dropout=None, **kwargs, @@ -133,7 +126,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 0a5652f117b7..bc863d9f94b1 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -64,7 +64,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -116,11 +115,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -169,38 +168,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -217,7 +191,7 @@ def eager_attention_forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->RobertaPreLayerNorm class RobertaPreLayerNormSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -236,12 +210,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -279,11 +247,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -294,8 +257,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -304,7 +265,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertCrossAttention with Bert->RobertaPreLayerNorm class RobertaPreLayerNormCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -323,12 +284,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -370,11 +325,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -385,8 +335,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -407,15 +355,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class RobertaPreLayerNormAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = RobertaPreLayerNormCrossAttention if is_cross_attention else RobertaPreLayerNormSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = RobertaPreLayerNormSelfOutput(config) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pruned_heads = set() @@ -507,7 +451,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = RobertaPreLayerNormAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, diff --git a/src/transformers/models/roc_bert/configuration_roc_bert.py b/src/transformers/models/roc_bert/configuration_roc_bert.py index 75f83e11a799..3aaa75a5cac3 100644 --- a/src/transformers/models/roc_bert/configuration_roc_bert.py +++ b/src/transformers/models/roc_bert/configuration_roc_bert.py @@ -65,12 +65,6 @@ class RoCBertConfig(PretrainedConfig): use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). classifier_dropout (`float`, *optional*): The dropout ratio for the classification head. 
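Note: the attention-side hunks repeated across these models all land in the same place: `eager_attention_forward` loses its `use_cache` argument and the relative-position branch, the 1/sqrt(head_dim) factor is applied directly when the raw scores are formed, and the guard that forced `attn_implementation="eager"` for non-absolute embeddings disappears. A condensed sketch of the remaining eager path is given below; only the first few lines are quoted from the patch, while the softmax / dropout / output tail follows the standard eager implementation and is an assumption, not patched source.

    import torch

    def eager_attention_forward_sketch(module, query, key, value, attention_mask,
                                       scaling=None, dropout=0.0):
        if scaling is None:
            scaling = query.size(-1) ** -0.5
        # Scaling is folded into the score computation; no relative-position terms remain.
        attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
        if attention_mask is not None and attention_mask.ndim == 4:
            attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output.transpose(1, 2).contiguous(), attn_weights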
enable_pronunciation (`bool`, *optional*, defaults to `True`): @@ -124,7 +118,6 @@ def __init__( layer_norm_eps=1e-12, use_cache=True, pad_token_id=0, - position_embedding_type="absolute", classifier_dropout=None, enable_pronunciation=True, enable_shape=True, @@ -155,7 +148,6 @@ def __init__( self.shape_embed_dim = shape_embed_dim self.shape_vocab_size = shape_vocab_size self.concat_input = concat_input - self.position_embedding_type = position_embedding_type self.classifier_dropout = classifier_dropout super().__init__(pad_token_id=pad_token_id, **kwargs) diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index b97787f557fd..309449de7609 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -88,7 +88,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), @@ -132,9 +131,8 @@ def forward( inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) @@ -172,9 +170,8 @@ def forward( token_type_embeddings = self.token_type_embeddings(token_type_ids) embedding_in += token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embedding_in += position_embeddings + position_embeddings = self.position_embeddings(position_ids) + embedding_in += position_embeddings embedding_in = self.LayerNorm(embedding_in) embedding_in = self.dropout(embedding_in) @@ -190,38 +187,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -238,7 +210,7 @@ def eager_attention_forward( # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->RoCBert class RoCBertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -257,12 +229,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -300,11 +266,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -315,8 +276,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -325,7 +284,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertCrossAttention with Bert->RoCBert class RoCBertCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -344,12 +303,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -391,11 +344,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -406,8 +354,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -431,15 +377,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RoCBert,BERT->ROC_BERT class RoCBertAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = RoCBertCrossAttention if is_cross_attention else RoCBertSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = RoCBertSelfOutput(config) self.pruned_heads = set() @@ -529,7 +471,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = RoCBertAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 490ae8ae4791..2f7ec4bb416b 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -53,7 +53,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") def forward( self, @@ -78,11 +77,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index 0f2f86799b7f..54981b1bc141 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -234,7 +234,7 @@ def forward( class SuperGlueSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -251,12 +251,6 @@ def __init__(self, config, position_embedding_type=None): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - 
self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder @@ -295,23 +289,6 @@ def forward( # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - query_length, key_length = query_layer.shape[2], key_layer.shape[2] - position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key - attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in SuperGlueModel forward() function) @@ -353,12 +330,9 @@ def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor: class SuperGlueAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() - self.self = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, - position_embedding_type=position_embedding_type, - ) + self.self = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation](config) self.output = SuperGlueSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 386883969916..bbe8cd5ed49b 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -232,7 +232,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -265,11 +264,11 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = 
inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings diff --git a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py index 97d6245cb1d7..4cf64e9e1884 100644 --- a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py @@ -66,12 +66,6 @@ class XLMRobertaConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). is_decoder (`bool`, *optional*, defaults to `False`): Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. use_cache (`bool`, *optional*, defaults to `True`): @@ -114,7 +108,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - position_embedding_type="absolute", use_cache=True, classifier_dropout=None, **kwargs, @@ -133,7 +126,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 00bbab96668d..9ebe5436c9df 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -65,38 +65,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
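Note on the removal that follows (and on the identical hunks for xlm_roberta_xl and xmod further down): once the relative_key / relative_key_query branches are gone, eager_attention_forward reduces to plain scaled dot-product attention. A minimal sketch of the resulting path is given here for orientation; the mask slicing mirrors the context lines visible in these hunks, while the softmax / dropout / value-matmul tail is assumed to match the standard Transformers eager implementation and is not shown verbatim by this patch.

# Illustrative sketch, not patch content.
from typing import Optional

import torch
from torch import nn


def simplified_eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,   # (batch, heads, q_len, head_dim)
    key: torch.Tensor,     # (batch, heads, k_len, head_dim)
    value: torch.Tensor,   # (batch, heads, k_len, head_dim)
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
):
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Plain scaled dot-product scores; no distance embeddings and no shifted scaling.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights

For 4-D inputs of shape (batch, heads, seq, head_dim) this reproduces the single `+` line above followed by the usual masked softmax; the use_cache argument disappears because it was only consumed by the removed relative-position branch.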
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -112,7 +87,7 @@ def eager_attention_forward( class XLMRobertaSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -131,12 +106,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -174,11 +143,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -189,8 +153,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -198,7 +160,7 @@ def forward( class XLMRobertaCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -217,12 +179,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -264,11 +220,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -279,8 +230,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -302,15 +251,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class XLMRobertaAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = XLMRobertaCrossAttention if is_cross_attention else XLMRobertaSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = XLMRobertaSelfOutput(config) self.pruned_heads = set() @@ -397,7 +342,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = XLMRobertaAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, @@ -526,7 +470,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -578,11 +521,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -689,8 +632,6 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None - self.position_embedding_type = config.position_embedding_type - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py index 4111a61d4e26..a7dda1b9b318 100644 --- a/src/transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py @@ -65,12 +65,6 @@ class XLMRobertaXLConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon used by the layer normalization layers. 
- position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. @@ -111,7 +105,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - position_embedding_type="absolute", use_cache=True, classifier_dropout=None, **kwargs, @@ -129,7 +122,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 7cb1af8a2e68..95b12c762e35 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -71,7 +71,6 @@ def __init__(self, config): self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -123,11 +122,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings embeddings = self.dropout(embeddings) return embeddings @@ -175,38 +173,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -222,7 +195,7 @@ def eager_attention_forward( class XLMRobertaXLSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -241,12 +214,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -284,11 +251,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -299,8 +261,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -308,7 +268,7 @@ def forward( class XLMRobertaXLCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -327,12 +287,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -374,11 +328,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -389,8 +338,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -411,15 +358,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class XLMRobertaXLAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = XLMRobertaXLCrossAttention if is_cross_attention else XLMRobertaXLSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = XLMRobertaXLSelfOutput(config) self.pruned_heads = set() @@ -506,7 +449,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = XLMRobertaXLAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, @@ -677,8 +619,6 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = XLMRobertaXLPooler(config) if add_pooling_layer else None - self.position_embedding_type = config.position_embedding_type - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py index a8fdf8433e29..c1c970cf7138 100644 --- a/src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py @@ -102,11 +102,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings embeddings = self.dropout(embeddings) return embeddings @@ -134,10 +133,8 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class XLMRobertaXLAttention(BertAttention): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): - super().__init__(config, position_embedding_type, is_causal, layer_idx, is_cross_attention) + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): + super().__init__(config, is_causal, layer_idx, is_cross_attention) del self.LayerNorm self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) diff --git a/src/transformers/models/xmod/configuration_xmod.py b/src/transformers/models/xmod/configuration_xmod.py index 41bad38a45de..8a0f77e278a9 100644 --- a/src/transformers/models/xmod/configuration_xmod.py +++ 
b/src/transformers/models/xmod/configuration_xmod.py @@ -65,12 +65,6 @@ class XmodConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For - positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to - [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155). - For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models - with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658). is_decoder (`bool`, *optional*, defaults to `False`): Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. use_cache (`bool`, *optional*, defaults to `True`): @@ -128,7 +122,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - position_embedding_type="absolute", use_cache=True, classifier_dropout=None, pre_norm=False, @@ -154,7 +147,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout self.pre_norm = pre_norm diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 5536eea30452..3af6d00baa2e 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -63,7 +63,6 @@ def __init__(self, config): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False ) @@ -115,11 +114,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -168,38 +167,13 @@ def eager_attention_forward( attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, dropout: float = 0.0, - use_cache: Optional[bool] = None, **kwargs: Unpack[TransformersKwargs], ): if scaling is None: scaling = query.size(-1) ** -0.5 # Take the dot product between "query" and "key" to get the raw attention scores. 
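The xmod config and embeddings hunks above repeat the two changes applied throughout this patch: the config no longer declares position_embedding_type, and the embeddings module adds absolute position embeddings unconditionally instead of guarding on the old attribute. A minimal sketch of the resulting embedding computation, with attribute names taken from the xmod hunk above and default sizes chosen only for illustration:

# Illustrative sketch, not patch content.
import torch
from torch import nn


class SimplifiedEmbeddings(nn.Module):
    def __init__(self, vocab_size=30522, hidden_size=768, max_position_embeddings=512,
                 type_vocab_size=2, layer_norm_eps=1e-12, dropout_prob=0.1, pad_token_id=1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout_prob)
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        seq_length = input_ids.shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        inputs_embeds = self.word_embeddings(input_ids)
        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        # Position embeddings are now added unconditionally; the former
        # `if self.position_embedding_type == "absolute":` guard is gone.
        embeddings = embeddings + self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(embeddings))

The LayerNorm-then-dropout tail follows the unchanged context lines of the xmod hunk; models such as xlm_roberta_xl that never had a LayerNorm at this point simply omit it.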
- attn_weights = torch.matmul(query, key.transpose(2, 3)) - - # Relative positional embeddings - if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": - query_length, key_length = query.shape[2], key.shape[2] - if use_cache: - position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) - else: - position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) - position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility - - if module.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - attn_weights = attn_weights + relative_position_scores - elif module.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) - attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key - - # Scaling is shifted in case of embeddings being relative - attn_weights = attn_weights * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask[:, :, :, : key.shape[-2]] @@ -216,7 +190,7 @@ def eager_attention_forward( # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Xmod class XmodSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -235,12 +209,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder self.is_causal = is_causal @@ -278,11 +246,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -293,8 +256,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -303,7 +264,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertCrossAttention with Bert->Xmod class XmodCrossAttention(nn.Module): - def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): + def __init__(self, config, is_causal=False, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -322,12 +283,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_ self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_causal = is_causal self.layer_idx = layer_idx @@ -369,11 +324,6 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.position_embedding_type != "absolute": - raise ValueError( - f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " - 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' 
- ) attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -384,8 +334,6 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.dropout.p, scaling=self.scaling, - # only for relevant for non-absolute positional embeddings - use_cache=past_key_value is not None, **kwargs, ) attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() @@ -408,15 +356,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class XmodAttention(nn.Module): - def __init__( - self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False - ): + def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False): super().__init__() self.is_cross_attention = is_cross_attention attention_class = XmodCrossAttention if is_cross_attention else XmodSelfAttention - self.self = attention_class( - config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx - ) + self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx) self.output = XmodSelfOutput(config) self.pruned_heads = set() self.pre_norm = config.pre_norm @@ -568,7 +512,6 @@ def __init__(self, config, layer_idx=None): raise ValueError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = XmodAttention( config, - position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx, is_cross_attention=True, diff --git a/src/transformers/models/yoso/configuration_yoso.py b/src/transformers/models/yoso/configuration_yoso.py index 9a7fb1218e40..e3efb9d09bd2 100644 --- a/src/transformers/models/yoso/configuration_yoso.py +++ b/src/transformers/models/yoso/configuration_yoso.py @@ -60,8 +60,6 @@ class YosoConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - position_embedding_type (`str`, *optional*, defaults to `"absolute"`): - Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. use_expectation (`bool`, *optional*, defaults to `True`): Whether or not to use YOSO Expectation. Overrides any effect of num_hash. 
hash_code_len (`int`, *optional*, defaults to 9): @@ -106,7 +104,6 @@ def __init__( type_vocab_size=1, initializer_range=0.02, layer_norm_eps=1e-12, - position_embedding_type="absolute", use_expectation=True, hash_code_len=9, num_hash=64, @@ -132,7 +129,6 @@ def __init__( self.initializer_range = initializer_range self.type_vocab_size = type_vocab_size self.layer_norm_eps = layer_norm_eps - self.position_embedding_type = position_embedding_type self.use_expectation = use_expectation self.hash_code_len = hash_code_len self.num_hash = num_hash diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index f830936cc7b7..375321576dc9 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -244,7 +244,6 @@ def __init__(self, config): self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2, persistent=False ) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.register_buffer( "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), @@ -276,18 +275,18 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class YosoSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -310,9 +309,6 @@ def __init__(self, config, position_embedding_type=None): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = ( - position_embedding_type if position_embedding_type is not None else config.position_embedding_type - ) self.use_expectation = config.use_expectation self.hash_code_len = config.hash_code_len @@ -449,9 +445,9 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class YosoAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config): super().__init__() - self.self = YosoSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = YosoSelfAttention(config) self.output = YosoSelfOutput(config) self.pruned_heads = set() diff --git a/tests/models/albert/test_modeling_albert.py b/tests/models/albert/test_modeling_albert.py index 193143d7b46a..6e0d5ef5603c 100644 --- a/tests/models/albert/test_modeling_albert.py +++ b/tests/models/albert/test_modeling_albert.py @@ -307,13 +307,6 @@ def test_for_sequence_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", 
"relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - config_and_inputs[0]._attn_implementation = "eager" - self.model_tester.create_and_check_model(*config_and_inputs) - @slow def test_model_from_pretrained(self): model_name = "albert/albert-base-v1" diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py index 65892b48fbaa..f8beb9457758 100644 --- a/tests/models/bert/test_modeling_bert.py +++ b/tests/models/bert/test_modeling_bert.py @@ -496,13 +496,6 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - config_and_inputs[0]._attn_implementation = "eager" - self.model_tester.create_and_check_model(*config_and_inputs) - def test_model_3d_mask_shapes(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() # manipulate input_mask @@ -588,12 +581,6 @@ def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) - def test_decoder_model_past_with_large_inputs_relative_pos_emb(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() - config_and_inputs[0].position_embedding_type = "relative_key" - config_and_inputs[0]._attn_implementation = "eager" - self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) - def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) @@ -754,40 +741,6 @@ def test_inference_no_head_absolute_embedding(self): torch.testing.assert_close(output[:, 1:4, 1:4], expected_slice, rtol=1e-4, atol=1e-4) - @slow - def test_inference_no_head_relative_embedding_key(self): - model = BertModel.from_pretrained( - "zhiheng-huang/bert-base-uncased-embedding-relative-key", attn_implementation="eager" - ) - input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]]) - attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) - with torch.no_grad(): - output = model(input_ids, attention_mask=attention_mask)[0] - expected_shape = torch.Size((1, 11, 768)) - self.assertEqual(output.shape, expected_shape) - expected_slice = torch.tensor( - [[[0.0756, 0.3142, -0.5128], [0.3761, 0.3462, -0.5477], [0.2052, 0.3760, -0.1240]]] - ) - - torch.testing.assert_close(output[:, 1:4, 1:4], expected_slice, rtol=1e-4, atol=1e-4) - - @slow - def test_inference_no_head_relative_embedding_key_query(self): - model = BertModel.from_pretrained( - "zhiheng-huang/bert-base-uncased-embedding-relative-key-query", attn_implementation="eager" - ) - input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]]) - attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) - with torch.no_grad(): - output = model(input_ids, attention_mask=attention_mask)[0] - expected_shape = torch.Size((1, 11, 768)) - self.assertEqual(output.shape, expected_shape) - expected_slice = torch.tensor( - [[[0.6496, 0.3784, 0.8203], [0.8148, 0.5656, 0.2636], [-0.0681, 0.5597, 0.7045]]] - ) - - torch.testing.assert_close(output[:, 1:4, 
1:4], expected_slice, rtol=1e-4, atol=1e-4) - @slow @pytest.mark.torch_export_test def test_export(self): diff --git a/tests/models/big_bird/test_modeling_big_bird.py b/tests/models/big_bird/test_modeling_big_bird.py index d6a54407015b..b34ea4c47bcf 100644 --- a/tests/models/big_bird/test_modeling_big_bird.py +++ b/tests/models/big_bird/test_modeling_big_bird.py @@ -70,7 +70,6 @@ def __init__( rescale_embeddings=False, block_size=8, num_rand_blocks=3, - position_embedding_type="absolute", scope=None, ): self.parent = parent @@ -101,7 +100,6 @@ def __init__( self.rescale_embeddings = rescale_embeddings self.block_size = block_size self.num_rand_blocks = num_rand_blocks - self.position_embedding_type = position_embedding_type def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -145,7 +143,6 @@ def get_config(self): rescale_embeddings=self.rescale_embeddings, block_size=self.block_size, num_random_blocks=self.num_rand_blocks, - position_embedding_type=self.position_embedding_type, ) def prepare_config_and_inputs_for_decoder(self): diff --git a/tests/models/biogpt/test_modeling_biogpt.py b/tests/models/biogpt/test_modeling_biogpt.py index 26f2053a93aa..ece5c3c9918c 100644 --- a/tests/models/biogpt/test_modeling_biogpt.py +++ b/tests/models/biogpt/test_modeling_biogpt.py @@ -283,12 +283,6 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - def test_biogpt_model_att_mask_past(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_biogpt_model_attention_mask_past(*config_and_inputs) diff --git a/tests/models/bitnet/test_modeling_bitnet.py b/tests/models/bitnet/test_modeling_bitnet.py index 58e3723e8317..efec72ca9b51 100644 --- a/tests/models/bitnet/test_modeling_bitnet.py +++ b/tests/models/bitnet/test_modeling_bitnet.py @@ -168,12 +168,6 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - @require_torch class BitNetIntegrationTest(unittest.TestCase): diff --git a/tests/models/bros/test_modeling_bros.py b/tests/models/bros/test_modeling_bros.py index 681c1e98bdd8..91b5b3861957 100644 --- a/tests/models/bros/test_modeling_bros.py +++ b/tests/models/bros/test_modeling_bros.py @@ -353,12 +353,6 @@ def test_model(self): def test_multi_gpu_data_parallel_forward(self): super().test_multi_gpu_data_parallel_forward() - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() 
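The test deletions around this point all remove loops that varied position_embedding_type over ["absolute", "relative_key", "relative_key_query"]; with the config field gone there is nothing left to vary. A hypothetical usage snippet, using the xlm_roberta classes changed earlier in this patch, of what the post-patch state looks like (illustrative only; the asserts describe the attributes removed in those hunks, assuming no other file reintroduces them):

# Illustrative sketch, not patch content.
from transformers import XLMRobertaConfig, XLMRobertaModel

config = XLMRobertaConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
    intermediate_size=64, max_position_embeddings=64,
)
model = XLMRobertaModel(config)  # always uses absolute position embeddings

# After this patch the attribute is gone from the config, the model, and its embeddings.
assert not hasattr(config, "position_embedding_type")
assert not hasattr(model, "position_embedding_type")
assert not hasattr(model.embeddings, "position_embedding_type")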
self.model_tester.create_and_check_for_token_classification(*config_and_inputs) diff --git a/tests/models/chinese_clip/test_modeling_chinese_clip.py b/tests/models/chinese_clip/test_modeling_chinese_clip.py index 140823b076d7..f7b0eda6dd27 100644 --- a/tests/models/chinese_clip/test_modeling_chinese_clip.py +++ b/tests/models/chinese_clip/test_modeling_chinese_clip.py @@ -342,12 +342,6 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - def test_model_as_decoder(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) diff --git a/tests/models/cohere/test_modeling_cohere.py b/tests/models/cohere/test_modeling_cohere.py index 25d7107a6652..83553e8a107e 100644 --- a/tests/models/cohere/test_modeling_cohere.py +++ b/tests/models/cohere/test_modeling_cohere.py @@ -188,12 +188,6 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - def test_torch_fx_output_loss(self): super().test_torch_fx_output_loss() diff --git a/tests/models/data2vec/test_modeling_data2vec_text.py b/tests/models/data2vec/test_modeling_data2vec_text.py index f8810685f0fc..4709920e1817 100644 --- a/tests/models/data2vec/test_modeling_data2vec_text.py +++ b/tests/models/data2vec/test_modeling_data2vec_text.py @@ -404,13 +404,6 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - config_and_inputs[0]._attn_implementation = "eager" - self.model_tester.create_and_check_model(*config_and_inputs) - def test_model_as_decoder(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) @@ -450,12 +443,6 @@ def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) - def test_decoder_model_past_with_large_inputs_relative_pos_emb(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() - config_and_inputs[0].position_embedding_type = "relative_key" - config_and_inputs[0]._attn_implementation = "eager" - self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) - def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) diff 
--git a/tests/models/dbrx/test_modeling_dbrx.py b/tests/models/dbrx/test_modeling_dbrx.py
index 538e07c34655..7f393cb1f3cc 100644
--- a/tests/models/dbrx/test_modeling_dbrx.py
+++ b/tests/models/dbrx/test_modeling_dbrx.py
@@ -97,12 +97,6 @@ class DbrxModelTest(CausalLMModelTest, unittest.TestCase):
     )
     model_tester_class = DbrxModelTester

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     @slow
     def test_model_from_pretrained(self):
         model_name = "eitanturok/dbrx-tiny"
diff --git a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py
index 6277b07093db..9c4219fec811 100644
--- a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py
+++ b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py
@@ -280,12 +280,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     @parameterized.expand([("yarn",)])
     def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/diffllama/test_modeling_diffllama.py b/tests/models/diffllama/test_modeling_diffllama.py
index b28f7c167b69..4aa6ab23b585 100644
--- a/tests/models/diffllama/test_modeling_diffllama.py
+++ b/tests/models/diffllama/test_modeling_diffllama.py
@@ -218,12 +218,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_diffllama_sequence_classification_model(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.num_labels = 3
diff --git a/tests/models/electra/test_modeling_electra.py b/tests/models/electra/test_modeling_electra.py
index 8019b01a767d..e4695e596e66 100644
--- a/tests/models/electra/test_modeling_electra.py
+++ b/tests/models/electra/test_modeling_electra.py
@@ -439,13 +439,6 @@ def test_electra_model_as_decoder(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_electra_model_as_decoder(*config_and_inputs)

-    def test_electra_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            config_and_inputs[0]._attn_implementation = "eager"
-            self.model_tester.create_and_check_electra_model(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_electra_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
index 042bf5d79d43..68f84986054f 100644
--- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
+++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
@@ -801,27 +801,6 @@ def prepare_config_and_inputs(self):
             "labels": decoder_token_labels,
         }

-    def test_relative_position_embeds(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-
-        encoder_config = config_and_inputs["config"]
-        decoder_config = config_and_inputs["decoder_config"]
-
-        encoder_config._attn_implementation = "eager"
-        decoder_config._attn_implementation = "eager"
-        encoder_config.position_embedding_type = "relative_key_query"
-        decoder_config.position_embedding_type = "relative_key_query"
-
-        encoder_model, decoder_model = self.get_encoder_decoder_model(encoder_config, decoder_config)
-        model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model).eval().to(torch_device)
-        model.config._attn_implementation = "eager"  # model config -> won't work
-
-        logits = model(
-            input_ids=config_and_inputs["input_ids"], decoder_input_ids=config_and_inputs["decoder_input_ids"]
-        ).logits
-
-        self.assertTrue(logits.shape, (13, 7))
-
     @slow
     def test_bert2bert_summarization(self):
         model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
diff --git a/tests/models/ernie/test_modeling_ernie.py b/tests/models/ernie/test_modeling_ernie.py
index b38d7f8633eb..7eb89aad9cbb 100644
--- a/tests/models/ernie/test_modeling_ernie.py
+++ b/tests/models/ernie/test_modeling_ernie.py
@@ -488,13 +488,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            config_and_inputs[0]._attn_implementation = "eager"
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_model_as_decoder(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
@@ -542,12 +535,6 @@ def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

-    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
-        config_and_inputs[0].position_embedding_type = "relative_key"
-        config_and_inputs[0]._attn_implementation = "eager"
-        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
-
     def test_for_multiple_choice(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py
index 72ef77c88c0a..c4cbf971f036 100644
--- a/tests/models/esm/test_modeling_esm.py
+++ b/tests/models/esm/test_modeling_esm.py
@@ -234,13 +234,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            config_and_inputs[0]._attn_implementation = "eager"
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py
index 5333e6ef9242..a80ab28e099b 100644
--- a/tests/models/flava/test_modeling_flava.py
+++ b/tests/models/flava/test_modeling_flava.py
@@ -338,7 +338,6 @@ def __init__(
         vocab_size=102,
         type_vocab_size=2,
         max_position_embeddings=512,
-        position_embedding_type="absolute",
         hidden_size=32,
         num_hidden_layers=2,
         num_attention_heads=4,
@@ -360,7 +359,6 @@ def __init__(
         self.vocab_size = vocab_size
         self.type_vocab_size = type_vocab_size
         self.max_position_embeddings = max_position_embeddings
-        self.position_embedding_type = position_embedding_type
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
@@ -401,7 +399,6 @@ def get_config(self):
             vocab_size=self.vocab_size,
             type_vocab_size=self.type_vocab_size,
             max_position_embeddings=self.max_position_embeddings,
-            position_embedding_type=self.position_embedding_type,
             hidden_size=self.hidden_size,
             num_hidden_layers=self.num_hidden_layers,
             num_attention_heads=self.num_attention_heads,
diff --git a/tests/models/git/test_modeling_git.py b/tests/models/git/test_modeling_git.py
index 3ade347881f4..d9a7c00b8793 100644
--- a/tests/models/git/test_modeling_git.py
+++ b/tests/models/git/test_modeling_git.py
@@ -416,12 +416,6 @@ def test_batched_generate_captioning(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester._test_batched_generate_captioning(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def _check_attentions_for_generate(
         self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
     ):
diff --git a/tests/models/granite/test_modeling_granite.py b/tests/models/granite/test_modeling_granite.py
index 4e6a9cfc5ab8..53f5ee5a2f92 100644
--- a/tests/models/granite/test_modeling_granite.py
+++ b/tests/models/granite/test_modeling_granite.py
@@ -197,12 +197,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     @parameterized.expand([("linear",), ("dynamic",)])
     def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/granite_speech/test_modeling_granite_speech.py b/tests/models/granite_speech/test_modeling_granite_speech.py
index 516d44896c2e..db87bec3ba90 100644
--- a/tests/models/granite_speech/test_modeling_granite_speech.py
+++ b/tests/models/granite_speech/test_modeling_granite_speech.py
@@ -107,7 +107,6 @@ def __init__(
             "model_type": "blip_2_qformer",
             "num_attention_heads": 4,
             "num_hidden_layers": 2,
-            "position_embedding_type": "absolute",
             "use_qformer_text_input": False,
             "vocab_size": 30522,
         },
diff --git a/tests/models/granitemoe/test_modeling_granitemoe.py b/tests/models/granitemoe/test_modeling_granitemoe.py
index da553d720bd0..cf513daaa8ec 100644
--- a/tests/models/granitemoe/test_modeling_granitemoe.py
+++ b/tests/models/granitemoe/test_modeling_granitemoe.py
@@ -196,12 +196,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     @parameterized.expand([("linear",), ("dynamic",)])
     def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/granitemoeshared/test_modeling_granitemoeshared.py b/tests/models/granitemoeshared/test_modeling_granitemoeshared.py
index 4d3f8c4e45be..eb941204bb11 100644
--- a/tests/models/granitemoeshared/test_modeling_granitemoeshared.py
+++ b/tests/models/granitemoeshared/test_modeling_granitemoeshared.py
@@ -199,12 +199,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     @parameterized.expand([("linear",), ("dynamic",)])
     def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/ibert/test_modeling_ibert.py b/tests/models/ibert/test_modeling_ibert.py
index b227f3a25147..b56d193aa68a 100644
--- a/tests/models/ibert/test_modeling_ibert.py
+++ b/tests/models/ibert/test_modeling_ibert.py
@@ -264,13 +264,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        # I-BERT only supports absolute embedding
-        for type in ["absolute"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/layoutlm/test_modeling_layoutlm.py b/tests/models/layoutlm/test_modeling_layoutlm.py
index a7cd87015609..422aaa22eb7b 100644
--- a/tests/models/layoutlm/test_modeling_layoutlm.py
+++ b/tests/models/layoutlm/test_modeling_layoutlm.py
@@ -256,12 +256,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/layoutlmv2/test_modeling_layoutlmv2.py b/tests/models/layoutlmv2/test_modeling_layoutlmv2.py
index 4faf6aa61b4a..e95aaad6b4b5 100644
--- a/tests/models/layoutlmv2/test_modeling_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_modeling_layoutlmv2.py
@@ -306,12 +306,6 @@ def test_model(self):
     def test_multi_gpu_data_parallel_forward(self):
         pass

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_sequence_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
diff --git a/tests/models/layoutlmv3/test_modeling_layoutlmv3.py b/tests/models/layoutlmv3/test_modeling_layoutlmv3.py
index e63ec1b5eb9d..fedbd1975649 100644
--- a/tests/models/layoutlmv3/test_modeling_layoutlmv3.py
+++ b/tests/models/layoutlmv3/test_modeling_layoutlmv3.py
@@ -354,12 +354,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_sequence_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
diff --git a/tests/models/lilt/test_modeling_lilt.py b/tests/models/lilt/test_modeling_lilt.py
index 949649a503df..d88f47b3c0d6 100644
--- a/tests/models/lilt/test_modeling_lilt.py
+++ b/tests/models/lilt/test_modeling_lilt.py
@@ -265,12 +265,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_token_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py
index db4180d6d97c..d71586cdfbed 100644
--- a/tests/models/modernbert/test_modeling_modernbert.py
+++ b/tests/models/modernbert/test_modeling_modernbert.py
@@ -294,12 +294,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/mra/test_modeling_mra.py b/tests/models/mra/test_modeling_mra.py
index 12b7725e6129..be4ff28ab06e 100644
--- a/tests/models/mra/test_modeling_mra.py
+++ b/tests/models/mra/test_modeling_mra.py
@@ -288,12 +288,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/nystromformer/test_modeling_nystromformer.py b/tests/models/nystromformer/test_modeling_nystromformer.py
index 18214582962e..11f7edc8b701 100644
--- a/tests/models/nystromformer/test_modeling_nystromformer.py
+++ b/tests/models/nystromformer/test_modeling_nystromformer.py
@@ -252,12 +252,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/olmo/test_modeling_olmo.py b/tests/models/olmo/test_modeling_olmo.py
index 2631823ba2f8..7c2802cc1cf6 100644
--- a/tests/models/olmo/test_modeling_olmo.py
+++ b/tests/models/olmo/test_modeling_olmo.py
@@ -190,12 +190,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     @parameterized.expand([("linear",), ("dynamic",)])
     def test_model_rope_scaling(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/olmo2/test_modeling_olmo2.py b/tests/models/olmo2/test_modeling_olmo2.py
index f90e45cdc858..cd0492221fbc 100644
--- a/tests/models/olmo2/test_modeling_olmo2.py
+++ b/tests/models/olmo2/test_modeling_olmo2.py
@@ -191,12 +191,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     @parameterized.expand([("linear",), ("dynamic",)])
     def test_model_rope_scaling(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/olmoe/test_modeling_olmoe.py b/tests/models/olmoe/test_modeling_olmoe.py
index ad02154567c2..47b6c605077a 100644
--- a/tests/models/olmoe/test_modeling_olmoe.py
+++ b/tests/models/olmoe/test_modeling_olmoe.py
@@ -202,12 +202,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     @parameterized.expand([("linear",), ("dynamic",)])
     def test_model_rope_scaling(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/rembert/test_modeling_rembert.py b/tests/models/rembert/test_modeling_rembert.py
index 93a16866601b..e142f866f202 100644
--- a/tests/models/rembert/test_modeling_rembert.py
+++ b/tests/models/rembert/test_modeling_rembert.py
@@ -381,12 +381,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/roberta/test_modeling_roberta.py b/tests/models/roberta/test_modeling_roberta.py
index e2e1b8e7b0f1..99032b83e8ed 100644
--- a/tests/models/roberta/test_modeling_roberta.py
+++ b/tests/models/roberta/test_modeling_roberta.py
@@ -413,13 +413,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            config_and_inputs[0]._attn_implementation = "eager"
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_model_as_decoder(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
@@ -459,12 +452,6 @@ def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

-    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
-        config_and_inputs[0].position_embedding_type = "relative_key"
-        config_and_inputs[0]._attn_implementation = "eager"
-        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py b/tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py
index 541f6ba2d8e7..f4d2adebfe52 100644
--- a/tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py
+++ b/tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py
@@ -413,14 +413,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_various_embeddings
-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            config_and_inputs[0]._attn_implementation = "eager"
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     # Copied from tests.models.roberta.test_modeling_roberta.RobertaModelTest.test_model_as_decoder
     def test_model_as_decoder(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
diff --git a/tests/models/roc_bert/test_modeling_roc_bert.py b/tests/models/roc_bert/test_modeling_roc_bert.py
index 09ad188b17b6..dc92ffbefd22 100644
--- a/tests/models/roc_bert/test_modeling_roc_bert.py
+++ b/tests/models/roc_bert/test_modeling_roc_bert.py
@@ -629,13 +629,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            config_and_inputs[0]._attn_implementation = "eager"
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
@@ -648,12 +641,6 @@ def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

-    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
-        config_and_inputs[0].position_embedding_type = "relative_key"
-        config_and_inputs[0]._attn_implementation = "eager"
-        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
-
     def test_for_question_answering(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
diff --git a/tests/models/splinter/test_modeling_splinter.py b/tests/models/splinter/test_modeling_splinter.py
index fbb9d4e7c210..59d4537171b2 100644
--- a/tests/models/splinter/test_modeling_splinter.py
+++ b/tests/models/splinter/test_modeling_splinter.py
@@ -283,12 +283,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_question_answering(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
diff --git a/tests/models/visual_bert/test_modeling_visual_bert.py b/tests/models/visual_bert/test_modeling_visual_bert.py
index 09c96a2467b0..a49419dd321f 100644
--- a/tests/models/visual_bert/test_modeling_visual_bert.py
+++ b/tests/models/visual_bert/test_modeling_visual_bert.py
@@ -522,12 +522,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_model_for_pretraining(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_pretraining()
         self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
diff --git a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py
index badc6f067c7f..e556dbf76a1d 100644
--- a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py
+++ b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py
@@ -420,13 +420,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            config_and_inputs[0]._attn_implementation = "eager"
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_model_as_decoder(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
@@ -466,12 +459,6 @@ def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

-    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
-        config_and_inputs[0].position_embedding_type = "relative_key"
-        config_and_inputs[0]._attn_implementation = "eager"
-        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/xmod/test_modeling_xmod.py b/tests/models/xmod/test_modeling_xmod.py
index f0b834feaca7..583bd51e6375 100644
--- a/tests/models/xmod/test_modeling_xmod.py
+++ b/tests/models/xmod/test_modeling_xmod.py
@@ -418,13 +418,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            config_and_inputs[0]._attn_implementation = "eager"
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_model_as_decoder(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
@@ -464,12 +457,6 @@ def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

-    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
-        config_and_inputs[0].position_embedding_type = "relative_key"
-        config_and_inputs[0]._attn_implementation = "eager"
-        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
diff --git a/tests/models/yoso/test_modeling_yoso.py b/tests/models/yoso/test_modeling_yoso.py
index 864127fa7c5a..38202962626a 100644
--- a/tests/models/yoso/test_modeling_yoso.py
+++ b/tests/models/yoso/test_modeling_yoso.py
@@ -286,12 +286,6 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)