Draft

T5 #149
91 changes: 91 additions & 0 deletions longformer/longformer.py
@@ -67,6 +67,17 @@ def __init__(self, config, layer_id):
self.key = nn.Linear(config.hidden_size, self.embed_dim)
self.value = nn.Linear(config.hidden_size, self.embed_dim)

# this is for the T5 setting
if "has_relative_attention_bias" in config.to_dict():
self.is_decoder = config.is_decoder
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.has_relative_attention_bias = config.has_relative_attention_bias
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads)
self.is_t5 = True
else:
self.is_t5 = False

Comment on lines +70 to +80
Collaborator: I would suggest moving all the T5-specific code from here to a longformer_encoder_decoder.LongformerSelfAttentionForT5 and have it inherit from LongformerSelfAttention.

self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
@@ -85,10 +96,72 @@ def __init__(self, config, layer_id):
assert not self.autoregressive # not supported
assert self.attention_dilation == 1 # dilation is not supported

@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is copied with no change, right? Please mention that.

https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)

# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact

# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)

relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
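For intuition, a quick check of the bucketing scheme described in the docstring (a hypothetical snippet run against this branch, using the defaults bidirectional=True, num_buckets=32, max_distance=128):

import torch
from longformer.longformer import LongformerSelfAttention

# Each direction gets num_buckets // 2 = 16 buckets: distances below
# max_exact = 8 get their own bucket, 8..127 share log-spaced buckets,
# and everything >= 128 falls into the last bucket of its direction.
rel = torch.tensor([[0, 1, 7, 8, 50, 127, 128, 500]])
print(LongformerSelfAttention._relative_position_bucket(rel))   # tensor([[ 0, 17, 23, 24, 29, 31, 31, 31]])
print(LongformerSelfAttention._relative_position_bucket(-rel))  # tensor([[ 0,  1,  7,  8, 13, 15, 15, 15]])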

def compute_bias(self, qlen, klen):
Collaborator: qlen, klen are not used.

""" Compute binned relative position bias """
# the sliding window only sees offsets in [-attention_window, attention_window], so compute the bias once per offset
relative_position = torch.tensor([[i - self.attention_window for i in range(2 * self.attention_window + 1)]])
Collaborator: Add a comment to explain the change.

rp_bucket = self._relative_position_bucket(
relative_position, # shape (1, 2*attention_window+1)
bidirectional=not self.is_decoder,
num_buckets=self.relative_attention_num_buckets,
)
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(rp_bucket) # shape (1, 2*attention_window+1, num_heads)
# values = values.permute([2, 0, 1]).unsqueeze(0) # original T5 shape (1, num_heads, qlen, klen)
# Changed to the shape below because that's what LongformerSelfAttention's attn_weights need.
values = values.permute([0, 2, 1]).unsqueeze(0) # shape (1, 1, num_heads, 2*attention_window+1)
return values
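To sanity-check how the returned bias lines up with LongformerSelfAttention's attention scores, here is a small standalone shape sketch (the window and head counts are made-up values):

import torch

# position_bias as returned by compute_bias: one value per head per relative
# offset, shared across batch and query positions
window, num_heads, bsz, seq_len = 4, 2, 3, 16
position_bias = torch.randn(1, 1, num_heads, 2 * window + 1)

# attn_weights inside LongformerSelfAttention (sliding_chunks mode)
attn_weights = torch.randn(bsz, seq_len, num_heads, 2 * window + 1)

# the `attn_weights += position_bias` in forward() relies on broadcasting
# over the batch and sequence dimensions
print((attn_weights + position_bias).shape)  # torch.Size([3, 16, 2, 9])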

def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
past_key_value_state=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
@@ -185,6 +258,24 @@ def forward(
# concat to attn_weights
# (bsz, seq_len, num_heads, extra attention count + 2*window+1)
attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)

if self.is_t5:
if position_bias is None:
if not self.has_relative_attention_bias:
raise ValueError("No position_bias provided and no weights to compute position_bias")

position_bias = self.compute_bias(seq_len, seq_len)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value_state is not None:
# NOTE: this slice assumes T5's original (1, num_heads, qlen, klen) layout;
# after the permute in compute_bias, dim 2 is num_heads, not the query position
position_bias = position_bias[:, :, -1:, :]
# TODO: attention_mask should also be the same shape as position_bias.
# Sliding attention window??
# if attention_mask is not None:
# position_bias = position_bias + attention_mask # (1, num_heads, seq_len, 2*window+1)
attn_weights += position_bias


Comment on lines +261 to +278
Collaborator: As above, move this to LongformerSelfAttentionForT5. Here you can keep only one line, something like attn_weights = self.process_relative_positions(attn_weights). This function is empty in LongformerSelfAttention but has more details in LongformerSelfAttentionForT5.
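For concreteness, a minimal sketch of the hook pattern the comment describes; every name below is illustrative, taken from the review rather than from existing code:

import torch
from torch import nn

class SlidingSelfAttentionBase(nn.Module):
    # plain Longformer attention: the hook is a no-op
    def process_relative_positions(self, attn_weights):
        return attn_weights

class SlidingSelfAttentionWithT5Bias(SlidingSelfAttentionBase):
    def __init__(self, num_heads: int, window: int):
        super().__init__()
        # one learned bias per head per relative offset in [-window, window]
        self.bias = nn.Parameter(torch.zeros(1, 1, num_heads, 2 * window + 1))

    def process_relative_positions(self, attn_weights):
        # attn_weights: (bsz, seq_len, num_heads, 2*window+1)
        return attn_weights + self.bias

attn = SlidingSelfAttentionWithT5Bias(num_heads=2, window=3)
scores = torch.randn(1, 5, 2, 7)
print(attn.process_relative_positions(scores).shape)  # torch.Size([1, 5, 2, 7])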

attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
if key_padding_mask is not None:
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
76 changes: 75 additions & 1 deletion longformer/longformer_encoder_decoder.py
@@ -2,7 +2,7 @@
from torch import nn, Tensor
from longformer.longformer import LongformerSelfAttention
from transformers.modeling_bart import BartConfig, BartForConditionalGeneration

from transformers.modeling_t5 import T5Config, T5ForConditionalGeneration

class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration):
def __init__(self, config):
@@ -74,3 +74,77 @@ def forward(
attn_output = self.output(outputs[0].transpose(0, 1))

return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)


class LongformerEncoderDecoderForConditionalGenerationT5(T5ForConditionalGeneration):
def __init__(self, config):
super().__init__(config)
if config.attention_mode == 'n2':
pass # do nothing, use the regular T5Attention instead
else:
for i, layer in enumerate(self.encoder.block):
layer.layer[0].SelfAttention = LongformerSelfAttentionForT5(config, layer_id=i)


class LongformerEncoderDecoderConfigT5(T5Config):
def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
has_relative_attention_bias: bool = True, gradient_checkpointing: bool = False,
**kwargs):
"""
Args:
attention_window: list of attention window sizes of length = number of layers.
window size = number of attention locations on each side.
For an effective window size of 512, use `attention_window=[256]*num_layers`
which is 256 on each side.
attention_dilation: list of attention dilation of length = number of layers.
attention dilation of `1` means no dilation.
autoregressive: do autoregressive attention or have attention on both sides
attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implementation of Longformer
self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
"""
super().__init__(**kwargs)
self.attention_window = attention_window
self.attention_dilation = attention_dilation
self.autoregressive = autoregressive
self.attention_mode = attention_mode
self.has_relative_attention_bias = has_relative_attention_bias
self.gradient_checkpointing = gradient_checkpointing
self.attention_probs_dropout_prob = self.dropout_rate
assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
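As a usage illustration (hypothetical values; t5-base has 12 encoder layers):

# hypothetical instantiation for an effective window of 512 tokens
config = LongformerEncoderDecoderConfigT5.from_pretrained(
    't5-base',
    attention_window=[256] * 12,  # 256 attention locations on each side
    attention_dilation=[1] * 12,  # no dilation
    attention_mode='sliding_chunks',
)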

class LongformerSelfAttentionForT5(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.embed_dim = config.d_model
self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
self.output = nn.Linear(self.embed_dim, self.embed_dim)

def forward(
self,
query,
mask=None,
kv=None,
position_bias=None,
past_key_value_state=None,
head_mask=None,
query_length=None,
use_cache=False,
output_attentions=False,
):

tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]

outputs = self.longformer_self_attn(
query, #.transpose(0, 1), # LongformerSelfAttention expects (bsz, seqlen, embd_dim)
#attention_mask=key_padding_mask.unsqueeze(dim=1).unsqueeze(dim=1) * -1,
attention_mask=mask, #.unsqueeze(dim=1).unsqueeze(dim=1)*-1,
output_attentions=output_attentions,
)

attn_output = self.output(outputs[0].transpose(0, 1))

return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)

148 changes: 148 additions & 0 deletions scripts/convert_t5_to_longformerencoderdecoder.py
@@ -0,0 +1,148 @@
import argparse
import logging
import os

from transformers import T5Tokenizer

from transformers import T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from longformer.longformer_encoder_decoder import LongformerSelfAttentionForT5, LongformerEncoderDecoderConfigT5
from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGenerationT5

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def create_long_model(
save_model_to,
base_model,
tokenizer_name_or_path,
attention_window,
max_pos
):
model = T5ForConditionalGeneration.from_pretrained(base_model)
tokenizer = T5Tokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos)
config = LongformerEncoderDecoderConfigT5.from_pretrained(base_model)
model.config = config

# in T5 attention_probs_dropout_prob is dropout_rate, but LongformerSelfAttention
# expects attention_probs_dropout_prob, so set it here
config.attention_probs_dropout_prob = config.dropout_rate
config.architectures = ['LongformerEncoderDecoderForConditionalGenerationT5', ]

# extend position embeddings
tokenizer.model_max_length = max_pos
tokenizer.init_kwargs['model_max_length'] = max_pos

# current_max_pos, embed_size = model.model.embed_positions.weight.shape
# assert current_max_pos == config.max_position_embeddings + 2

# config.max_encoder_position_embeddings = max_pos
# config.max_decoder_position_embeddings = config.max_position_embeddings
# del config.max_position_embeddings
# # TODO: check what's the deal with T5 here.
# max_pos += 2 # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2
# assert max_pos >= current_max_pos

# # allocate a larger position embedding matrix for the encoder
# new_encoder_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)
# # copy position embeddings over and over to initialize the new position embeddings
# k = 2
# step = current_max_pos - 2
# while k < max_pos - 1:
# new_encoder_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:]
# k += step
# model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed

# replace the `modeling_t5.T5Attention` object with `LongformerSelfAttention`
config.attention_window = [attention_window] * config.num_hidden_layers
config.attention_dilation = [1] * config.num_hidden_layers
# NOTE: this keeps only the first encoder block (likely a debugging leftover in this draft)
model.encoder.block = model.encoder.block[:1]

for i, layer in enumerate(model.encoder.block):
self_attn = layer.layer[0].SelfAttention

longformer_self_attn_for_t5 = LongformerSelfAttentionForT5(config, layer_id=i)

longformer_self_attn_for_t5.longformer_self_attn.query = self_attn.q
longformer_self_attn_for_t5.longformer_self_attn.key = self_attn.k
longformer_self_attn_for_t5.longformer_self_attn.value = self_attn.v

longformer_self_attn_for_t5.longformer_self_attn.query_global = self_attn.q
longformer_self_attn_for_t5.longformer_self_attn.key_global = self_attn.k
longformer_self_attn_for_t5.longformer_self_attn.value_global = self_attn.v

longformer_self_attn_for_t5.output = self_attn.o

layer.layer[0].SelfAttention = longformer_self_attn_for_t5

logger.info(f'saving model to {save_model_to}')
model.save_pretrained(save_model_to)
tokenizer.save_pretrained(save_model_to)
return model, tokenizer


def main():
parser = argparse.ArgumentParser(description="Convert T5 to LongT5. Replaces T5 encoder's T5Attention with LongformerSelfAttention")
parser.add_argument(
'--base_model',
type=str,
default='t5-large',
help='The name or path of the base model you want to convert'
)
parser.add_argument(
'--tokenizer_name_or_path',
type=str,
default='t5-large',
help='The name or path of the tokenizer'
)
parser.add_argument(
'--save_model_to',
type=str,
required=True,
help='The path to save the converted model'
)
parser.add_argument(
'--attention_window',
type=int,
default=512,
help='attention window size for longformer self attention (one sided)'
)
parser.add_argument(
'--max_pos',
type=int,
default=4096 * 4,
help='maximum encoder positions'
)

args = parser.parse_args()

if not os.path.exists(args.save_model_to):
os.mkdir(args.save_model_to)

create_long_model(
save_model_to=args.save_model_to,
base_model=args.base_model,
tokenizer_name_or_path=args.tokenizer_name_or_path,
attention_window=args.attention_window,
max_pos=args.max_pos
)

tokenizer = T5Tokenizer.from_pretrained(args.save_model_to)
TXT = "My friends are <extra_id_0> but they eat too many carbs."  # T5 has no <mask> token; use a sentinel token instead
model = LongformerEncoderDecoderForConditionalGenerationT5.from_pretrained(args.save_model_to)
model.encoder.config.gradient_checkpointing = True
model.decoder.config.gradient_checkpointing = True
data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048)
input_ids = data['input_ids']
attention_mask = data['attention_mask']
decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id)
logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0]
masked_index = (input_ids[0] == tokenizer.convert_tokens_to_ids('<extra_id_0>')).nonzero().item()  # T5Tokenizer has no mask_token_id
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5)
print(tokenizer.convert_ids_to_tokens(predictions))


if __name__ == "__main__":
main()