77 changes: 9 additions & 68 deletions examples/modular-transformers/modeling_dummy_bert.py
Contributor Author:
This is mostly due to me forgetting to update them in my BERT refactor PR, so the diff is big because the whole refactor is included (same for the roberta example).

Contributor Author:
Updated: the diff now only includes the changes from this PR.
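
For quick reference, a minimal sketch of what the embedding path computes after this change: absolute position embeddings are always added, and the `position_embedding_type` branch is gone. This is not the generated file; the layer sizes and defaults below are illustrative assumptions.

```python
# Hedged sketch of the simplified embedding forward (mirrors the diff below).
import torch
from torch import nn


class TinyBertEmbeddingsSketch(nn.Module):
    def __init__(self, vocab_size=30522, hidden_size=64, max_positions=512, type_vocab_size=2):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_positions, hidden_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(0.1)
        # position_ids buffer, as in the generated modules
        self.register_buffer(
            "position_ids", torch.arange(max_positions).expand((1, -1)), persistent=False
        )

    def forward(self, input_ids, token_type_ids=None):
        seq_len = input_ids.shape[1]
        position_ids = self.position_ids[:, :seq_len]
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        inputs_embeds = self.word_embeddings(input_ids)
        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        # Position embeddings are added unconditionally now.
        embeddings = embeddings + self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(embeddings))
```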

@@ -42,7 +42,6 @@ def __init__(self, config):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
@@ -83,11 +82,11 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings

position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings

embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
@@ -101,38 +100,13 @@ def eager_attention_forward(
attention_mask: Optional[torch.Tensor],
scaling: Optional[float] = None,
dropout: float = 0.0,
use_cache: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
):
if scaling is None:
scaling = query.size(-1) ** -0.5

# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(2, 3))

# Relative positional embeddings
if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
query_length, key_length = query.shape[2], key.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
distance = position_ids_l - position_ids_r

positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility

if module.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
attn_weights = attn_weights + relative_position_scores
elif module.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key

# Scaling is shifted in case of embeddings being relative
attn_weights = attn_weights * scaling
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

if attention_mask is not None and attention_mask.ndim == 4:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
@@ -148,7 +122,7 @@ def eager_attention_forward(


class DummyBertSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
def __init__(self, config, is_causal=False, layer_idx=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@@ -167,12 +141,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

self.is_decoder = config.is_decoder
self.is_causal = is_causal
@@ -210,11 +178,6 @@ def forward(

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.position_embedding_type != "absolute":
raise ValueError(
f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
)
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
@@ -225,16 +188,14 @@ def forward(
attention_mask,
dropout=0.0 if not self.training else self.dropout.p,
scaling=self.scaling,
# only for relevant for non-absolute positional embeddings
use_cache=past_key_value is not None,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
return attn_output, attn_weights


class DummyBertCrossAttention(nn.Module):
def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
def __init__(self, config, is_causal=False, layer_idx=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@@ -253,12 +214,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

self.is_causal = is_causal
self.layer_idx = layer_idx
@@ -300,11 +255,6 @@ def forward(

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.position_embedding_type != "absolute":
raise ValueError(
f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
)
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
@@ -315,8 +265,6 @@ def forward(
attention_mask,
dropout=0.0 if not self.training else self.dropout.p,
scaling=self.scaling,
# only for relevant for non-absolute positional embeddings
use_cache=past_key_value is not None,
**kwargs,
)
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
@@ -338,15 +286,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to


class DummyBertAttention(nn.Module):
def __init__(
self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False
):
def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
super().__init__()
self.is_cross_attention = is_cross_attention
attention_class = DummyBertCrossAttention if is_cross_attention else DummyBertSelfAttention
self.self = attention_class(
config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx
)
self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
self.output = DummyBertSelfOutput(config)
self.pruned_heads = set()

@@ -433,7 +377,6 @@ def __init__(self, config, layer_idx=None):
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = DummyBertAttention(
config,
position_embedding_type="absolute",
is_causal=False,
layer_idx=layer_idx,
is_cross_attention=True,
@@ -638,8 +581,6 @@ def __init__(self, config, add_pooling_layer=True):

self.pooler = DummyBertPooler(config) if add_pooling_layer else None

self.position_embedding_type = config.position_embedding_type

# Initialize weights and apply final processing
self.post_init()

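Since the attention hunks above are partially collapsed, here is a hedged, self-contained sketch of the simplified eager attention path after this PR. The scaled dot product and mask cropping follow the diff; the softmax, dropout, and value steps are reconstructed from the standard eager-attention pattern and are assumptions, not quoted from the generated file.

```python
# Sketch of eager attention without the relative positional-embedding branch.
from typing import Optional

import torch
import torch.nn.functional as F


def eager_attention_sketch(
    query: torch.Tensor,                    # (batch, heads, q_len, head_dim)
    key: torch.Tensor,                      # (batch, heads, k_len, head_dim)
    value: torch.Tensor,                    # (batch, heads, k_len, head_dim)
    attention_mask: Optional[torch.Tensor], # additive mask, broadcastable to the scores
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    training: bool = False,
):
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Raw attention scores: plain scaled dot product between query and key.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None and attention_mask.ndim == 4:
        # Crop the mask to the key length, as the diff does.
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + attention_mask

    # Assumed standard continuation (collapsed in the diff above).
    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = F.dropout(attn_weights, p=dropout, training=training)

    attn_output = torch.matmul(attn_weights, value)          # (batch, heads, q_len, head_dim)
    attn_output = attn_output.transpose(1, 2).contiguous()   # (batch, q_len, heads, head_dim)
    return attn_output, attn_weights
```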
77 changes: 9 additions & 68 deletions examples/modular-transformers/modeling_roberta.py
@@ -44,7 +44,6 @@ def __init__(self, config):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
@@ -86,11 +85,11 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings

position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings

embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
@@ -104,38 +103,13 @@ def eager_attention_forward(
attention_mask: Optional[torch.Tensor],
scaling: Optional[float] = None,
dropout: float = 0.0,
use_cache: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
):
if scaling is None:
scaling = query.size(-1) ** -0.5

# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(2, 3))

# Relative positional embeddings
if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
query_length, key_length = query.shape[2], key.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
distance = position_ids_l - position_ids_r

positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility

if module.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
attn_weights = attn_weights + relative_position_scores
elif module.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key

# Scaling is shifted in case of embeddings being relative
attn_weights = attn_weights * scaling
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

if attention_mask is not None and attention_mask.ndim == 4:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
@@ -151,7 +125,7 @@ def eager_attention_forward(


class RobertaSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
def __init__(self, config, is_causal=False, layer_idx=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@@ -170,12 +144,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

self.is_decoder = config.is_decoder
self.is_causal = is_causal
@@ -213,11 +181,6 @@ def forward(

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.position_embedding_type != "absolute":
raise ValueError(
f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
)
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
@@ -228,16 +191,14 @@ def forward(
attention_mask,
dropout=0.0 if not self.training else self.dropout.p,
scaling=self.scaling,
# only for relevant for non-absolute positional embeddings
use_cache=past_key_value is not None,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
return attn_output, attn_weights


class RobertaCrossAttention(nn.Module):
def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
def __init__(self, config, is_causal=False, layer_idx=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@@ -256,12 +217,6 @@ def __init__(self, config, position_embedding_type=None, is_causal=False, layer_
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

self.is_causal = is_causal
self.layer_idx = layer_idx
@@ -303,11 +258,6 @@ def forward(

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.position_embedding_type != "absolute":
raise ValueError(
f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
)
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
Expand All @@ -318,8 +268,6 @@ def forward(
attention_mask,
dropout=0.0 if not self.training else self.dropout.p,
scaling=self.scaling,
# only for relevant for non-absolute positional embeddings
use_cache=past_key_value is not None,
**kwargs,
)
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
Expand All @@ -341,15 +289,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to


class RobertaAttention(nn.Module):
def __init__(
self, config, position_embedding_type=None, is_causal=False, layer_idx=None, is_cross_attention=False
):
def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
super().__init__()
self.is_cross_attention = is_cross_attention
attention_class = RobertaCrossAttention if is_cross_attention else RobertaSelfAttention
self.self = attention_class(
config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx
)
self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
self.output = RobertaSelfOutput(config)
self.pruned_heads = set()

Expand Down Expand Up @@ -436,7 +380,6 @@ def __init__(self, config, layer_idx=None):
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = RobertaAttention(
config,
position_embedding_type="absolute",
is_causal=False,
layer_idx=layer_idx,
is_cross_attention=True,
Expand Down Expand Up @@ -641,8 +584,6 @@ def __init__(self, config, add_pooling_layer=True):

self.pooler = RobertaPooler(config) if add_pooling_layer else None

self.position_embedding_type = config.position_embedding_type

# Initialize weights and apply final processing
self.post_init()
