T5 #149
base: master
@@ -67,6 +67,17 @@ def __init__(self, config, layer_id):

```python
        self.key = nn.Linear(config.hidden_size, self.embed_dim)
        self.value = nn.Linear(config.hidden_size, self.embed_dim)

        # this is for the T5 setting
        if "has_relative_attention_bias" in config.to_dict():
            self.is_decoder = config.is_decoder
            self.relative_attention_num_buckets = config.relative_attention_num_buckets
            self.has_relative_attention_bias = config.has_relative_attention_bias
            if self.has_relative_attention_bias:
                self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads)
            self.is_t5 = True
        else:
            self.is_t5 = False

        self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
        self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
        self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
```
@@ -85,10 +96,72 @@ def __init__(self, config, layer_id):

```python
        assert not self.autoregressive  # not supported
        assert self.attention_dilation == 1  # dilation is not supported

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
```
Collaborator: This function is copied with no change, right? Please mention that.
```python
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_postion_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_postion_if_large = torch.min(
            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
        return relative_buckets
```
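Not part of the diff: a small illustrative sketch of what the copied helper does with its default arguments, assuming it is reachable as the static method added above (the import path and the example offsets are assumptions).

```python
# Illustrative only: how signed token offsets map to buckets with the defaults above.
import torch

from longformer.longformer import LongformerSelfAttention  # import path assumed

rel = torch.tensor([[-200, -16, -2, 0, 2, 16, 200]])
buckets = LongformerSelfAttention._relative_position_bucket(
    rel, bidirectional=True, num_buckets=32, max_distance=128
)
# Small offsets get their own buckets, larger ones fall into logarithmic bins,
# and offsets at or beyond max_distance share the last bucket on each side.
print(buckets)
```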
```python
    def compute_bias(self, qlen, klen):
```
Collaborator: qlen, klen are not used.

```python
        """ Compute binned relative position bias """
        relative_position = torch.tensor([[i-self.attention_window for i in range(2*self.attention_window+1)]])
```
Collaborator: Add a comment to explain the change.
```python
        rp_bucket = self._relative_position_bucket(
            relative_position,  # shape (qlen, klen)
            bidirectional=not self.is_decoder,
            num_buckets=self.relative_attention_num_buckets,
        )
        rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
        # values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
        # Changing the shape to below because that's what LongformerSelfAttention's attn_weights need.
        values = values.permute([0, 2, 1]).unsqueeze(0)  # shape (1, qlen, num_heads, klen)
        return values
```
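Not part of the diff, but a toy-sized sanity check on the shape comment above (all sizes are assumed): the new permute is what lets the bias broadcast against Longformer's sliding-window attention scores of shape (bsz, seq_len, num_heads, 2*window+1).

```python
# Toy sketch with assumed sizes, not code from this PR.
import torch

bsz, seq_len, num_heads, window = 2, 16, 12, 4
attn_weights = torch.zeros(bsz, seq_len, num_heads, 2 * window + 1)

values = torch.randn(1, 2 * window + 1, num_heads)      # output of the nn.Embedding lookup
position_bias = values.permute([0, 2, 1]).unsqueeze(0)  # -> (1, 1, num_heads, 2*window+1)

print((attn_weights + position_bias).shape)             # broadcasts to torch.Size([2, 16, 12, 9])
```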
```python
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        past_key_value_state=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
```
@@ -185,6 +258,24 @@ def forward(

```python
            # concat to attn_weights
            # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
            attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)

        if self.is_t5:
            if position_bias is None:
                if not self.has_relative_attention_bias:
                    raise ValueError("No position_bias provided and no weights to compute position_bias")

                position_bias = self.compute_bias(seq_len, seq_len)
                # if key and values are already calculated
                # we want only the last query position bias
                if past_key_value_state is not None:
                    position_bias = position_bias[:, :, -1:, :]
            # TODO: attention_mask should also be the same shape as position_bias.
            # Sliding attention window??
            # if attention_mask is not None:
            #     position_bias = position_bias + attention_mask  # (1, num_heads, seq_len, 2*window+1)
            attn_weights += position_bias
```
Collaborator (on lines +261 to +278): As above, move to LongformerSelfAttentionForT5. Here you can keep only one line, something like:
```python
        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
        if key_padding_mask is not None:
            # softmax sometimes inserts NaN if all positions are masked, replace them with 0
```
New file:

@@ -0,0 +1,148 @@
```python
import argparse
import logging
import os

from transformers import T5Tokenizer

from transformers import T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from longformer.longformer_encoder_decoder import LongformerSelfAttentionForT5, LongformerEncoderDecoderConfigT5
from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGenerationT5

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def create_long_model(
    save_model_to,
    base_model,
    tokenizer_name_or_path,
    attention_window,
    max_pos
):
    model = T5ForConditionalGeneration.from_pretrained(base_model)
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos)
    config = LongformerEncoderDecoderConfigT5.from_pretrained(base_model)
    model.config = config

    # in T5 attention_probs_dropout_prob is dropout_rate, but LongformerSelfAttention
    # expects attention_probs_dropout_prob, so set it here
    config.attention_probs_dropout_prob = config.dropout_rate
    config.architectures = ['LongformerEncoderDecoderForConditionalGenerationT5', ]

    # extend position embeddings
    tokenizer.model_max_length = max_pos
    tokenizer.init_kwargs['model_max_length'] = max_pos

    # current_max_pos, embed_size = model.model.embed_positions.weight.shape
    # assert current_max_pos == config.max_position_embeddings + 2

    # config.max_encoder_position_embeddings = max_pos
    # config.max_decoder_position_embeddings = config.max_position_embeddings
    # del config.max_position_embeddings
    # # TODO: check what's the deal with T5 here.
    # max_pos += 2  # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2
    # assert max_pos >= current_max_pos

    # # allocate a larger position embedding matrix for the encoder
    # new_encoder_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)
    # # copy position embeddings over and over to initialize the new position embeddings
    # k = 2
    # step = current_max_pos - 2
    # while k < max_pos - 1:
    #     new_encoder_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:]
    #     k += step
    # model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed

    # replace the `modeling_t5.T5Attention` object with `LongformerSelfAttention`
    config.attention_window = [attention_window] * config.num_hidden_layers
    config.attention_dilation = [1] * config.num_hidden_layers
    model.encoder.block = model.encoder.block[:1]

    for i, layer in enumerate(model.encoder.block):
        self_attn = layer.layer[0].SelfAttention

        longformer_self_attn_for_t5 = LongformerSelfAttentionForT5(config, layer_id=i)

        longformer_self_attn_for_t5.longformer_self_attn.query = self_attn.q
        longformer_self_attn_for_t5.longformer_self_attn.key = self_attn.k
        longformer_self_attn_for_t5.longformer_self_attn.value = self_attn.v

        longformer_self_attn_for_t5.longformer_self_attn.query_global = self_attn.q
        longformer_self_attn_for_t5.longformer_self_attn.key_global = self_attn.k
        longformer_self_attn_for_t5.longformer_self_attn.value_global = self_attn.v

        longformer_self_attn_for_t5.output = self_attn.o

        layer.layer[0].SelfAttention = longformer_self_attn_for_t5

    logger.info(f'saving model to {save_model_to}')
    model.save_pretrained(save_model_to)
    tokenizer.save_pretrained(save_model_to)
    return model, tokenizer


def main():
    parser = argparse.ArgumentParser(description="Convert T5 to LongT5. Replaces T5 encoder's T5Attention with LongformerSelfAttention")
    parser.add_argument(
        '--base_model',
        type=str,
        default='t5-large',
        help='The name or path of the base model you want to convert'
    )
    parser.add_argument(
        '--tokenizer_name_or_path',
        type=str,
        default='t5-large',
        help='The name or path of the tokenizer'
    )
    parser.add_argument(
        '--save_model_to',
        type=str,
        required=True,
        help='The path to save the converted model'
    )
    parser.add_argument(
        '--attention_window',
        type=int,
        default=512,
        help='attention window size for longformer self attention (one sided)'
    )
    parser.add_argument(
        '--max_pos',
        type=int,
        default=4096 * 4,
        help='maximum encoder positions'
    )

    args = parser.parse_args()

    if not os.path.exists(args.save_model_to):
        os.mkdir(args.save_model_to)

    create_long_model(
        save_model_to=args.save_model_to,
        base_model=args.base_model,
        tokenizer_name_or_path=args.tokenizer_name_or_path,
        attention_window=args.attention_window,
        max_pos=args.max_pos
    )

    tokenizer = T5Tokenizer.from_pretrained(args.save_model_to)
    TXT = "My friends are <mask> but they eat too many carbs."
    model = LongformerEncoderDecoderForConditionalGenerationT5.from_pretrained(args.save_model_to)
    model.encoder.config.gradient_checkpointing = True
    model.decoder.config.gradient_checkpointing = True
    data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048)
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id)
    logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0]
    masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
    probs = logits[0, masked_index].softmax(dim=0)
    values, predictions = probs.topk(5)
    print(tokenizer.convert_ids_to_tokens(predictions))


if __name__ == "__main__":
    main()
```
Collaborator: I would suggest moving all the T5-specific code from here to a longformer_encoder_decoder.LongformerSelfAttentionForT5 and having it inherit from LongformerSelfAttention.
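A rough sketch of the kind of subclass this suggestion describes (names, import path, and details are assumed here, not taken from the PR):

```python
# Rough sketch of the reviewer's suggestion only, not code from this PR:
# keep the T5-specific pieces in a subclass so LongformerSelfAttention itself stays unchanged.
import torch.nn as nn

from longformer.longformer import LongformerSelfAttention  # import path assumed


class LongformerSelfAttentionForT5(LongformerSelfAttention):
    def __init__(self, config, layer_id):
        super().__init__(config, layer_id)
        self.is_decoder = config.is_decoder
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.has_relative_attention_bias = config.has_relative_attention_bias
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(
                self.relative_attention_num_buckets, self.num_heads
            )

    # _relative_position_bucket and compute_bias would move here unchanged, and an
    # overridden forward() would add the computed position_bias to the attention
    # scores, leaving the base class free of T5-specific branches.
```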