From 572da0780230dd9e51277a702cfab082dd2b05f6 Mon Sep 17 00:00:00 2001
From: Akshita Bhagia
Date: Mon, 7 Dec 2020 23:58:07 -0800
Subject: [PATCH 1/3] adding t5 options to longformer

---
 longformer/longformer.py | 91 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 91 insertions(+)

diff --git a/longformer/longformer.py b/longformer/longformer.py
index 14da60f..bdf3691 100644
--- a/longformer/longformer.py
+++ b/longformer/longformer.py
@@ -67,6 +67,17 @@ def __init__(self, config, layer_id):
         self.key = nn.Linear(config.hidden_size, self.embed_dim)
         self.value = nn.Linear(config.hidden_size, self.embed_dim)
 
+        # this is for the T5 setting
+        if "has_relative_attention_bias" in config.to_dict():
+            self.is_decoder = config.is_decoder
+            self.relative_attention_num_buckets = config.relative_attention_num_buckets
+            self.has_relative_attention_bias = has_relative_attention_bias
+            if self.has_relative_attention_bias:
+                self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads)
+            self.is_t5 = True
+        else:
+            self.is_t5 = False
+
         self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
         self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
         self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
@@ -85,10 +96,72 @@ def __init__(self, config, layer_id):
         assert not self.autoregressive  # not supported
         assert self.attention_dilation == 1  # dilation is not supported
 
+    @staticmethod
+    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        """
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+        Translate relative position to a bucket number for relative attention. The relative position is defined as
+        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same
+        bucket.
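+        (A worked example, derived from the code below: with bidirectional=False, num_buckets=32, and
+        max_distance=128, distances 0 to 15 map to buckets 0 to 15 one-to-one, distances 16 to 127 fall into
+        logarithmically sized buckets 16 to 31, and every distance >= 128 is clipped into bucket 31.)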
+        This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+        Args:
+            relative_position: an int32 Tensor
+            bidirectional: a boolean - whether the attention is bidirectional
+            num_buckets: an integer
+            max_distance: an integer
+        Returns:
+            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+        """
+        relative_buckets = 0
+        if bidirectional:
+            num_buckets //= 2
+            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+            relative_position = torch.abs(relative_position)
+        else:
+            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+        # now relative_position is in the range [0, inf)
+
+        # half of the buckets are for exact increments in positions
+        max_exact = num_buckets // 2
+        is_small = relative_position < max_exact
+
+        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+        relative_position_if_large = max_exact + (
+            torch.log(relative_position.float() / max_exact)
+            / math.log(max_distance / max_exact)
+            * (num_buckets - max_exact)
+        ).to(torch.long)
+        relative_position_if_large = torch.min(
+            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+        )
+
+        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+        return relative_buckets
+
+    def compute_bias(self, qlen, klen):
+        """ Compute binned relative position bias """
+        # relative positions over the local window: [-attention_window, ..., +attention_window],
+        # shape (1, 2*window+1) rather than (qlen, klen) as in the original T5 implementation
+        relative_position = torch.tensor([[i - self.attention_window for i in range(2 * self.attention_window + 1)]])
+        rp_bucket = self._relative_position_bucket(
+            relative_position,  # shape (1, 2*window+1)
+            bidirectional=not self.is_decoder,
+            num_buckets=self.relative_attention_num_buckets,
+        )
+        rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
+        values = self.relative_attention_bias(rp_bucket)  # shape (1, 2*window+1, num_heads)
+        # values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
+        # Changing the shape to below because that's what LongformerSelfAttention's attn_weights need.
+        values = values.permute([0, 2, 1]).unsqueeze(0)  # shape (1, 1, num_heads, 2*window+1), broadcasts over attn_weights
+        return values
+
     def forward(
         self,
         hidden_states,
         attention_mask=None,
+        position_bias=None,
+        past_key_value_state=None,
         head_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
@@ -185,6 +258,24 @@ def forward(
             # concat to attn_weights
             # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
             attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
+
+        if self.is_t5:
+            if position_bias is None:
+                if not self.has_relative_attention_bias:
+                    raise ValueError("No position_bias provided and no weights to compute position_bias")
+
+                position_bias = self.compute_bias(seq_len, seq_len)
+                # if key and values are already calculated
+                # we want only the last query position bias
+                if past_key_value_state is not None:
+                    position_bias = position_bias[:, :, -1:, :]
+            # TODO: attention_mask should also be the same shape as position_bias.
+            # Sliding attention window??
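+            # (Note: attn_weights here has shape (bsz, seq_len, num_heads, 2*window+1), one score per
+            # diagonal of the sliding window, so a standard (bsz, 1, 1, seq_len) additive mask cannot
+            # simply be added; it would first have to be re-laid-out into the same diagonal format.)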
+            # if attention_mask is not None:
+            #     position_bias = position_bias + attention_mask  # (1, num_heads, seq_len, 2*window+1)
+            attn_weights += position_bias
+
         attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
         if key_padding_mask is not None:
             # softmax sometimes inserts NaN if all positions are masked, replace them with 0

From 2b44bf8e4784116977c456af25cd1771eb9205ef Mon Sep 17 00:00:00 2001
From: Akshita Bhagia
Date: Tue, 8 Dec 2020 00:00:12 -0800
Subject: [PATCH 2/3] t5 encoder decoder options

---
 longformer/longformer_encoder_decoder.py | 76 +++++++++++++++++++++++-
 1 file changed, 75 insertions(+), 1 deletion(-)

diff --git a/longformer/longformer_encoder_decoder.py b/longformer/longformer_encoder_decoder.py
index df38224..b21b2a4 100644
--- a/longformer/longformer_encoder_decoder.py
+++ b/longformer/longformer_encoder_decoder.py
@@ -2,7 +2,7 @@
 from torch import nn, Tensor
 from longformer.longformer import LongformerSelfAttention
 from transformers.modeling_bart import BartConfig, BartForConditionalGeneration
-
+from transformers.modeling_t5 import T5Config, T5ForConditionalGeneration
 
 class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration):
     def __init__(self, config):
@@ -74,3 +74,77 @@ def forward(
         attn_output = self.output(outputs[0].transpose(0, 1))
 
         return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)
+
+
+class LongformerEncoderDecoderForConditionalGenerationT5(T5ForConditionalGeneration):
+    def __init__(self, config):
+        super().__init__(config)
+        if config.attention_mode == 'n2':
+            pass  # do nothing, use BertSelfAttention instead
+        else:
+            for i, layer in enumerate(self.encoder.block):
+                layer.layer[0].SelfAttention = LongformerSelfAttentionForT5(config, layer_id=i)
+
+
+class LongformerEncoderDecoderConfigT5(T5Config):
+    def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
+                 autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
+                 has_relative_attention_bias: bool = False, gradient_checkpointing: bool = False,
+                 **kwargs):
+        """
+        Args:
+            attention_window: list of attention window sizes of length = number of layers.
+                window size = number of attention locations on each side.
+                For an effective window size of 512, use `attention_window=[256]*num_layers`
+                which is 256 on each side.
+            attention_dilation: list of attention dilation of length = number of layers.
+                attention dilation of `1` means no dilation.
+            autoregressive: whether to use autoregressive attention or to attend to both sides
+            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of Longformer
+                self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
+        """
+        super().__init__(**kwargs)
+        self.attention_window = attention_window
+        self.attention_dilation = attention_dilation
+        self.autoregressive = autoregressive
+        self.attention_mode = attention_mode
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.gradient_checkpointing = gradient_checkpointing
+        self.attention_probs_dropout_prob = self.dropout_rate
+        assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
+
+
+class LongformerSelfAttentionForT5(nn.Module):
+    def __init__(self, config, layer_id):
+        super().__init__()
+        self.embed_dim = config.d_model
+        self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
+        self.output = nn.Linear(self.embed_dim, self.embed_dim)
+
+    def forward(
+        self,
+        query,
+        mask=None,
+        kv=None,
+        position_bias=None,
+        past_key_value_state=None,
+        head_mask=None,
+        query_length=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+
+        tgt_len, bsz, embed_dim = query.size()
+        assert embed_dim == self.embed_dim
+        assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+        outputs = self.longformer_self_attn(
+            query,  # .transpose(0, 1), LongformerSelfAttention expects (bsz, seqlen, embd_dim)
+            # attention_mask=key_padding_mask.unsqueeze(dim=1).unsqueeze(dim=1) * -1,
+            attention_mask=mask,  # .unsqueeze(dim=1).unsqueeze(dim=1)*-1,
+            output_attentions=output_attentions,
+        )
+
+        attn_output = self.output(outputs[0].transpose(0, 1))
+
+        return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)
+

From 2c3860af8f978f1741865503b812b0d8594b477c Mon Sep 17 00:00:00 2001
From: Akshita Bhagia
Date: Tue, 8 Dec 2020 00:14:32 -0800
Subject: [PATCH 3/3] adding convert script

---
 longformer/longformer.py                      |   2 +-
 longformer/longformer_encoder_decoder.py      |   2 +-
 .../convert_t5_to_longformerencoderdecoder.py | 148 ++++++++++++++++++
 3 files changed, 150 insertions(+), 2 deletions(-)
 create mode 100644 scripts/convert_t5_to_longformerencoderdecoder.py

diff --git a/longformer/longformer.py b/longformer/longformer.py
index bdf3691..ce93f8f 100644
--- a/longformer/longformer.py
+++ b/longformer/longformer.py
@@ -71,7 +71,7 @@ def __init__(self, config, layer_id):
         if "has_relative_attention_bias" in config.to_dict():
             self.is_decoder = config.is_decoder
             self.relative_attention_num_buckets = config.relative_attention_num_buckets
-            self.has_relative_attention_bias = has_relative_attention_bias
+            self.has_relative_attention_bias = config.has_relative_attention_bias
             if self.has_relative_attention_bias:
                 self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads)
             self.is_t5 = True

diff --git a/longformer/longformer_encoder_decoder.py b/longformer/longformer_encoder_decoder.py
index b21b2a4..145bb6b 100644
--- a/longformer/longformer_encoder_decoder.py
+++ b/longformer/longformer_encoder_decoder.py
@@ -89,7 +89,7 @@ def __init__(self, config):
 class LongformerEncoderDecoderConfigT5(T5Config):
     def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
                  autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
-                 has_relative_attention_bias: bool = False, gradient_checkpointing: bool = False,
+                 has_relative_attention_bias: bool = True, gradient_checkpointing: bool = False,
                  **kwargs):
""" Args: diff --git a/scripts/convert_t5_to_longformerencoderdecoder.py b/scripts/convert_t5_to_longformerencoderdecoder.py new file mode 100644 index 0000000..6219b0f --- /dev/null +++ b/scripts/convert_t5_to_longformerencoderdecoder.py @@ -0,0 +1,148 @@ +import argparse +import logging +import os + +from transformers import T5Tokenizer + +from transformers import T5ForConditionalGeneration +from transformers.modeling_bart import shift_tokens_right +from longformer.longformer_encoder_decoder import LongformerSelfAttentionForT5, LongformerEncoderDecoderConfigT5 +from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGenerationT5 + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def create_long_model( + save_model_to, + base_model, + tokenizer_name_or_path, + attention_window, + max_pos +): + model = T5ForConditionalGeneration.from_pretrained(base_model) + tokenizer = T5Tokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos) + config = LongformerEncoderDecoderConfigT5.from_pretrained(base_model) + model.config = config + + # in T5 attention_probs_dropout_prob is dropout_rate, but LongformerSelfAttention + # expects attention_probs_dropout_prob, so set it here + config.attention_probs_dropout_prob = config.dropout_rate + config.architectures = ['LongformerEncoderDecoderForConditionalGenerationT5', ] + + # extend position embeddings + tokenizer.model_max_length = max_pos + tokenizer.init_kwargs['model_max_length'] = max_pos + + # current_max_pos, embed_size = model.model.embed_positions.weight.shape + # assert current_max_pos == config.max_position_embeddings + 2 + + # config.max_encoder_position_embeddings = max_pos + # config.max_decoder_position_embeddings = config.max_position_embeddings + # del config.max_position_embeddings + # # TODO: check what's the deal with T5 here. 
+    # max_pos += 2  # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2
+    # assert max_pos >= current_max_pos
+
+    # # allocate a larger position embedding matrix for the encoder
+    # new_encoder_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)
+    # # copy position embeddings over and over to initialize the new position embeddings
+    # k = 2
+    # step = current_max_pos - 2
+    # while k < max_pos - 1:
+    #     new_encoder_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:]
+    #     k += step
+    # model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed
+
+    # replace the `modeling_t5.T5Attention` object with `LongformerSelfAttention`
+    config.attention_window = [attention_window] * config.num_hidden_layers
+    config.attention_dilation = [1] * config.num_hidden_layers
+    model.encoder.block = model.encoder.block[:1]  # FIXME: this truncates the encoder to one block; debugging leftover?
+
+    for i, layer in enumerate(model.encoder.block):
+        self_attn = layer.layer[0].SelfAttention
+
+        longformer_self_attn_for_t5 = LongformerSelfAttentionForT5(config, layer_id=i)
+
+        longformer_self_attn_for_t5.longformer_self_attn.query = self_attn.q
+        longformer_self_attn_for_t5.longformer_self_attn.key = self_attn.k
+        longformer_self_attn_for_t5.longformer_self_attn.value = self_attn.v
+
+        longformer_self_attn_for_t5.longformer_self_attn.query_global = self_attn.q
+        longformer_self_attn_for_t5.longformer_self_attn.key_global = self_attn.k
+        longformer_self_attn_for_t5.longformer_self_attn.value_global = self_attn.v
+
+        longformer_self_attn_for_t5.output = self_attn.o
+
+        layer.layer[0].SelfAttention = longformer_self_attn_for_t5
+
+    logger.info(f'saving model to {save_model_to}')
+    model.save_pretrained(save_model_to)
+    tokenizer.save_pretrained(save_model_to)
+    return model, tokenizer
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Convert T5 to LongT5. Replaces T5 encoder's T5Attention with LongformerSelfAttention")
+    parser.add_argument(
+        '--base_model',
+        type=str,
+        default='t5-large',
+        help='The name or path of the base model you want to convert'
+    )
+    parser.add_argument(
+        '--tokenizer_name_or_path',
+        type=str,
+        default='t5-large',
+        help='The name or path of the tokenizer'
+    )
+    parser.add_argument(
+        '--save_model_to',
+        type=str,
+        required=True,
+        help='The path to save the converted model'
+    )
+    parser.add_argument(
+        '--attention_window',
+        type=int,
+        default=512,
+        help='attention window size for longformer self attention (one sided)'
+    )
+    parser.add_argument(
+        '--max_pos',
+        type=int,
+        default=4096 * 4,
+        help='maximum encoder positions'
+    )
+
+    args = parser.parse_args()
+
+    if not os.path.exists(args.save_model_to):
+        os.mkdir(args.save_model_to)
+
+    create_long_model(
+        save_model_to=args.save_model_to,
+        base_model=args.base_model,
+        tokenizer_name_or_path=args.tokenizer_name_or_path,
+        attention_window=args.attention_window,
+        max_pos=args.max_pos
+    )
+
+    tokenizer = T5Tokenizer.from_pretrained(args.save_model_to)
+    # T5 has no <mask> token; use its first sentinel token for the fill-in test below
+    TXT = "My friends are <extra_id_0> but they eat too many carbs."
+    model = LongformerEncoderDecoderForConditionalGenerationT5.from_pretrained(args.save_model_to)
+    model.encoder.config.gradient_checkpointing = True
+    model.decoder.config.gradient_checkpointing = True
+    data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048)
+    input_ids = data['input_ids']
+    attention_mask = data['attention_mask']
+    decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id)
+    logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0]
+    # T5Tokenizer has no mask_token_id, so locate the sentinel token instead
+    masked_index = (input_ids[0] == tokenizer.convert_tokens_to_ids('<extra_id_0>')).nonzero().item()
+    probs = logits[0, masked_index].softmax(dim=0)
+    values, predictions = probs.topk(5)
+    print(tokenizer.convert_ids_to_tokens(predictions))
+
+
+if __name__ == "__main__":
+    main()
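
---

Usage sketch (not part of the patches; the output path and t5-base are illustrative, the flags are the
ones defined in the script above): after converting a base checkpoint with

    python scripts/convert_t5_to_longformerencoderdecoder.py \
        --base_model t5-base \
        --tokenizer_name_or_path t5-base \
        --save_model_to /tmp/long-t5-base \
        --attention_window 512 \
        --max_pos 4096

the saved checkpoint can be reloaded with the classes added in PATCH 2/3:

    from transformers import T5Tokenizer
    from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGenerationT5

    tokenizer = T5Tokenizer.from_pretrained('/tmp/long-t5-base')
    model = LongformerEncoderDecoderForConditionalGenerationT5.from_pretrained('/tmp/long-t5-base')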