diff --git a/README.md b/README.md
index 466500d..aed7e73 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,24 @@
# `Longformer`
`Longformer` is a BERT-like model for long documents.
+
+**\*\*\*\*\* Work In Progress: LongformerEncoderDecoder \*\*\*\*\***
+
+A `LongformerEncoderDecoder` model is now available. It is geared towards summarization tasks where the input is long but the output is relatively short. The following code snippet loads a `LongformerEncoderDecoder` checkpoint initialized from `BART`. With gradient checkpointing, fp16, and a 48GB GPU, the input length can be up to 16K tokens.
+```
+pip install git+https://github.com/allenai/longformer.git@encoderdecoder
+
+# checkpoint-base: https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/longformer-encdec-base-16384.tar.gz
+# checkpoint-large: https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/longformer-encdec-large-16384.tar.gz
+
+from longformer import LongformerEncoderDecoderForConditionalGeneration
+model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(downloaded_checkpoint, gradient_checkpointing=True)
+```
+
+- Check the script `scripts/summarization.py` for an example of how to use the model.
+
+- Make sure to use the huggingface/transformers fork specified in `requirements.txt`.
+
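+The following is a rough usage sketch for summarization-style generation with the loaded model. The tokenizer loading, sequence length, and generation arguments are illustrative assumptions (check `scripts/summarization.py` for the actual setup), and `long_document` stands for your input text:
+```
+from transformers import BartTokenizer
+from longformer import LongformerEncoderDecoderForConditionalGeneration
+
+# assumes the checkpoint directory also contains the tokenizer files
+tokenizer = BartTokenizer.from_pretrained(downloaded_checkpoint)
+model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(downloaded_checkpoint)
+
+data = tokenizer([long_document], return_tensors='pt', padding='max_length', max_length=4096)
+summary_ids = model.generate(data['input_ids'], attention_mask=data['attention_mask'],
+                             num_beams=4, max_length=256, early_stopping=True)
+print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
+```
+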
**\*\*\*\*\* New July 23rd, 2020: Speed degradation \*\*\*\*\***
A significant speed degradation in huggingface/transformers was recently discovered and fixed (check [this PR](https://github.com/huggingface/transformers/pull/5811) for details). To avoid this problem, either use the old [release v2.11.0](https://github.com/huggingface/transformers/tree/v2.11.0) (which doesn't support gradient checkpointing) or use the master branch. This problem should be fixed in the next huggingface/transformers release.
diff --git a/experiment.yml b/experiment.yml
new file mode 100644
index 0000000..156faf5
--- /dev/null
+++ b/experiment.yml
@@ -0,0 +1,18 @@
+tasks:
+ - cluster: {{.Env.CLUSTER}}
+ spec:
+ # This is a python3.7/nvidia base image with basic libraries
+ image: im_j69gti4atcw9
+ resultPath: {{.Env.RESULT_PATH}}
+ args:
+ - /bin/bash
+ - -c
+ - "cd /longformer_on_beaker && pip install . && {{.Env.ARGS}}"
+ datasetMounts:
+ - datasetId: {{.Env.INPUT_DATASET_ID}}
+ containerPath: /data
+ - datasetId: {{.Env.SCRIPTS}}
+ containerPath: /longformer_on_beaker
+ requirements:
+ gpuCount: {{.Env.GPU_COUNT}}
+ cpu: {{.Env.CPU_COUNT}}
diff --git a/longformer/__init__.py b/longformer/__init__.py
index e69de29..d3e343c 100644
--- a/longformer/__init__.py
+++ b/longformer/__init__.py
@@ -0,0 +1,3 @@
+from longformer.longformer import Longformer, LongformerForMaskedLM, LongformerConfig
+from longformer.longformer_encoder_decoder import LongformerEncoderDecoderConfig
+from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration
\ No newline at end of file
diff --git a/longformer/longformer.py b/longformer/longformer.py
index 953bd2c..ecf19d1 100644
--- a/longformer/longformer.py
+++ b/longformer/longformer.py
@@ -5,6 +5,7 @@
import torch.nn.functional as F
from longformer.diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations
from longformer.sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv
+from longformer.sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv
from transformers.modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM
@@ -48,7 +49,7 @@ def __init__(self, attention_window: List[int] = None, attention_dilation: List[
self.attention_dilation = attention_dilation
self.autoregressive = autoregressive
self.attention_mode = attention_mode
- assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
+ assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2', 'sliding_chunks_no_overlap']
class LongformerSelfAttention(nn.Module):
@@ -58,7 +59,6 @@ def __init__(self, config, layer_id):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
- self.output_attentions = config.output_attentions
self.num_heads = config.num_attention_heads
self.head_dim = int(config.hidden_size / config.num_attention_heads)
self.embed_dim = config.hidden_size
@@ -80,8 +80,8 @@ def __init__(self, config, layer_id):
self.autoregressive = config.autoregressive
assert self.attention_window > 0
assert self.attention_dilation > 0
- assert self.attention_mode in ['tvm', 'sliding_chunks']
- if self.attention_mode == 'sliding_chunks':
+ assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap']
+ if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']:
assert not self.autoregressive # not supported
assert self.attention_dilation == 1 # dilation is not supported
@@ -147,8 +147,12 @@ def forward(
q = q.float().contiguous()
k = k.float().contiguous()
attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False)
- else: # "sliding_chunks"
+ elif self.attention_mode == "sliding_chunks":
attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0)
+ elif self.attention_mode == "sliding_chunks_no_overlap":
+ attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0)
+ else:
+            raise ValueError(f"Unsupported attention mode: {self.attention_mode}")
mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
if remove_from_windowed_attention_mask is not None:
# This implementation is fast and takes very little memory because num_heads x hidden_size = 1
@@ -162,10 +166,14 @@ def forward(
# diagonal mask with zeros everywhere and -inf inplace of padding
if self.attention_mode == 'tvm':
d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False)
- else:
+ elif self.attention_mode == "sliding_chunks":
d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
+ elif self.attention_mode == "sliding_chunks_no_overlap":
+ d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
+
attn_weights += d_mask
- assert list(attn_weights.size()) == [bsz, seq_len, self.num_heads, self.attention_window * 2 + 1]
+ assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads]
+ assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3]
# the extra attention
if extra_attention_mask is not None:
@@ -182,7 +190,6 @@ def forward(
if key_padding_mask is not None:
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
-
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
@@ -199,8 +206,12 @@ def forward(
if self.attention_mode == 'tvm':
v = v.float().contiguous()
attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False)
- else: # "sliding_chunks"
+ elif self.attention_mode == "sliding_chunks":
attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window)
+ elif self.attention_mode == "sliding_chunks_no_overlap":
+ attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window)
+ else:
+            raise ValueError(f"Unsupported attention mode: {self.attention_mode}")
attn = attn.type_as(hidden_states)
assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
diff --git a/longformer/longformer_encoder_decoder.py b/longformer/longformer_encoder_decoder.py
new file mode 100644
index 0000000..df38224
--- /dev/null
+++ b/longformer/longformer_encoder_decoder.py
@@ -0,0 +1,76 @@
+from typing import List, Optional, Tuple, Dict
+from torch import nn, Tensor
+from longformer.longformer import LongformerSelfAttention
+from transformers.modeling_bart import BartConfig, BartForConditionalGeneration
+
+
+class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration):
+ def __init__(self, config):
+ super().__init__(config)
+ if config.attention_mode == 'n2':
+            pass  # do nothing, use the default BART self-attention instead
+ else:
+ for i, layer in enumerate(self.model.encoder.layers):
+ layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)
+
+
+class LongformerEncoderDecoderConfig(BartConfig):
+ def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
+ autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
+ gradient_checkpointing: bool = False, **kwargs):
+ """
+ Args:
+ attention_window: list of attention window sizes of length = number of layers.
+ window size = number of attention locations on each side.
+                For an effective window size of 512, use `attention_window=[256]*num_layers`
+ which is 256 on each side.
+ attention_dilation: list of attention dilation of length = number of layers.
+ attention dilation of `1` means no dilation.
+ autoregressive: do autoregressive attention or have attention of both sides
+            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implementation of Longformer
+                self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
+ """
+ super().__init__(**kwargs)
+ self.attention_window = attention_window
+ self.attention_dilation = attention_dilation
+ self.autoregressive = autoregressive
+ self.attention_mode = attention_mode
+ self.gradient_checkpointing = gradient_checkpointing
+ assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
+
+
+class LongformerSelfAttentionForBart(nn.Module):
+ def __init__(self, config, layer_id):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
+ self.output = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def forward(
+ self,
+ query,
+ key: Optional[Tensor],
+ key_padding_mask: Optional[Tensor] = None,
+ layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
+ attn_mask: Optional[Tensor] = None,
+ need_weights=False,
+ output_attentions=False,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+ assert attn_mask is None
+
+ outputs = self.longformer_self_attn(
+ query.transpose(0, 1), # LongformerSelfAttention expects (bsz, seqlen, embd_dim)
+ attention_mask=key_padding_mask.unsqueeze(dim=1).unsqueeze(dim=1) * -1,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions=output_attentions,
+ )
+
+ attn_output = self.output(outputs[0].transpose(0, 1))
+
+ return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)
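+
+
+# A minimal, illustrative sketch (not part of the library) of building a config with an effective
+# window of 512 tokens per layer (256 on each side), mirroring the docstring above and the settings
+# used in scripts/convert_bart_to_longformerencoderdecoder.py. The base model name is an assumption
+# for illustration:
+#
+#     config = LongformerEncoderDecoderConfig.from_pretrained('facebook/bart-large')
+#     config.attention_probs_dropout_prob = config.attention_dropout
+#     config.attention_mode = 'sliding_chunks'
+#     config.attention_window = [256] * config.num_hidden_layers
+#     config.attention_dilation = [1] * config.num_hidden_layers
+#     model = LongformerEncoderDecoderForConditionalGeneration(config)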
diff --git a/longformer/longformer_t5_encoder_decoder.py b/longformer/longformer_t5_encoder_decoder.py
new file mode 100644
index 0000000..85cf786
--- /dev/null
+++ b/longformer/longformer_t5_encoder_decoder.py
@@ -0,0 +1,379 @@
+import math
+from typing import List, Optional, Tuple, Dict
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+from longformer.longformer import LongformerSelfAttention
+from longformer.diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations
+from longformer.sliding_chunks import *
+from transformers.modeling_t5 import T5Config, T5ForConditionalGeneration
+
+
+class LongformerEncoderDecoderForConditionalGenerationT5(T5ForConditionalGeneration):
+ def __init__(self, config):
+ super().__init__(config)
+ if config.attention_mode == 'n2':
+            pass  # do nothing, use the default T5 attention instead
+ else:
+ for i, layer in enumerate(self.encoder.block):
+ layer.layer[0].SelfAttention = LongformerSelfAttentionForT5(config, layer_id=i)
+
+
+class LongformerEncoderDecoderConfigT5(T5Config):
+ def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
+ autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
+ gradient_checkpointing: bool = False, **kwargs):
+ """
+ Args:
+ attention_window: list of attention window sizes of length = number of layers.
+ window size = number of attention locations on each side.
+                For an effective window size of 512, use `attention_window=[256]*num_layers`
+ which is 256 on each side.
+ attention_dilation: list of attention dilation of length = number of layers.
+ attention dilation of `1` means no dilation.
+ autoregressive: do autoregressive attention or have attention of both sides
+            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implementation of Longformer
+                self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
+ """
+ super().__init__(**kwargs)
+ self.attention_window = attention_window
+ self.attention_dilation = attention_dilation
+ self.autoregressive = autoregressive
+ self.attention_mode = attention_mode
+ self.gradient_checkpointing = gradient_checkpointing
+ self.attention_probs_dropout_prob = self.dropout_rate
+ assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
+
+class LongformerSelfAttentionT5Basic(nn.Module):
+ def __init__(self, config, layer_id, has_relative_attention_bias=False):
+ super(LongformerSelfAttentionT5Basic, self).__init__()
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+ self.num_heads = config.num_attention_heads
+ self.head_dim = int(config.hidden_size / config.num_attention_heads)
+ self.embed_dim = config.hidden_size
+
+ self.query = nn.Linear(config.hidden_size, self.embed_dim)
+ self.key = nn.Linear(config.hidden_size, self.embed_dim)
+ self.value = nn.Linear(config.hidden_size, self.embed_dim)
+
+ # this is for the T5 setting
+ self.is_decoder = config.is_decoder
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
+ self.has_relative_attention_bias = has_relative_attention_bias
+
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads)
+
+ self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
+ self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
+ self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
+
+ self.dropout = config.attention_probs_dropout_prob
+
+ self.layer_id = layer_id
+ self.attention_window = config.attention_window[self.layer_id]
+ self.attention_dilation = config.attention_dilation[self.layer_id]
+ self.attention_mode = config.attention_mode
+ self.autoregressive = config.autoregressive
+ assert self.attention_window > 0
+ assert self.attention_dilation > 0
+ assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap']
+ if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']:
+ assert not self.autoregressive # not supported
+ assert self.attention_dilation == 1 # dilation is not supported
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ """
+ Adapted from Mesh Tensorflow:
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
+ Args:
+ relative_position: an int32 Tensor
+ bidirectional: a boolean - whether the attention is bidirectional
+ num_buckets: an integer
+ max_distance: an integer
+ Returns:
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+ """
+ relative_buckets = 0
+ if bidirectional:
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+ relative_position = torch.abs(relative_position)
+ else:
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ relative_postion_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.long)
+ relative_postion_if_large = torch.min(
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
+ )
+
+ relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
+ return relative_buckets
+
+ def compute_bias(self, qlen, klen):
+ """ Compute binned relative position bias """
+ relative_position = torch.tensor([[i-self.attention_window for i in range(2*self.attention_window+1)]])
+ rp_bucket = self._relative_position_bucket(
+ relative_position, # shape (qlen, klen)
+ bidirectional=not self.is_decoder,
+ num_buckets=self.relative_attention_num_buckets,
+ )
+ rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
+ values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
+# values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
+ # Changing the shape to below because that's what LongformerSelfAttention's attn_weights need.
+ values = values.permute([0, 2, 1]).unsqueeze(0) # shape (1, qlen, num_heads, klen)
+ return values
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ position_bias=None,
+ past_key_value_state=None,
+ head_mask=None,
+ output_attentions=False,
+ ):
+ '''
+ The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to
+ -ve: no attention
+ 0: local attention
+ +ve: global attention
+ '''
+ if attention_mask is not None:
+ attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
+ key_padding_mask = attention_mask < 0
+ extra_attention_mask = attention_mask > 0
+ remove_from_windowed_attention_mask = attention_mask != 0
+
+ num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1)
+ max_num_extra_indices_per_batch = num_extra_indices_per_batch.max()
+ if max_num_extra_indices_per_batch <= 0:
+ extra_attention_mask = None
+ else:
+ # To support the case of variable number of global attention in the rows of a batch,
+ # we use the following three selection masks to select global attention embeddings
+ # in a 3d tensor and pad it to `max_num_extra_indices_per_batch`
+ # 1) selecting embeddings that correspond to global attention
+ extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True)
+ zero_to_max_range = torch.arange(0, max_num_extra_indices_per_batch,
+ device=num_extra_indices_per_batch.device)
+ # mask indicating which values are actually going to be padding
+ selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1)
+ # 2) location of the non-padding values in the selected global attention
+ selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True)
+ # 3) location of the padding values in the selected global attention
+ selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True)
+ else:
+ remove_from_windowed_attention_mask = None
+ extra_attention_mask = None
+ key_padding_mask = None
+
+ hidden_states = hidden_states.transpose(0, 1)
+ seq_len, bsz, embed_dim = hidden_states.size()
+ assert embed_dim == self.embed_dim
+ q = self.query(hidden_states)
+ k = self.key(hidden_states)
+ v = self.value(hidden_states)
+ q /= math.sqrt(self.head_dim)
+
+ q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+ k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+ # attn_weights = (bsz, seq_len, num_heads, window*2+1)
+ if self.attention_mode == 'tvm':
+ q = q.float().contiguous()
+ k = k.float().contiguous()
+ attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False)
+ elif self.attention_mode == "sliding_chunks":
+ attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0)
+ elif self.attention_mode == "sliding_chunks_no_overlap":
+ attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0)
+ else:
+            raise ValueError(f"Unsupported attention mode: {self.attention_mode}")
+ mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
+ if remove_from_windowed_attention_mask is not None:
+ # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
+ # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size)
+ remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1)
+ # cast to float/half then replace 1's with -inf
+ float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(remove_from_windowed_attention_mask, -10000.0)
+ repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation)
+ float_mask = float_mask.repeat(1, 1, repeat_size, 1)
+ ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones
+ # diagonal mask with zeros everywhere and -inf inplace of padding
+ if self.attention_mode == 'tvm':
+ d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False)
+ elif self.attention_mode == "sliding_chunks":
+ d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
+ elif self.attention_mode == "sliding_chunks_no_overlap":
+ d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
+
+ attn_weights += d_mask
+ assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads]
+ assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3]
+
+ # the extra attention
+ if extra_attention_mask is not None:
+ selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
+ selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros]
+ # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
+ selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k))
+ selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
+ # concat to attn_weights
+ # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
+ attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
+
+ if position_bias is None:
+ if not self.has_relative_attention_bias:
+ raise ValueError("No position_bias provided and no weights to compute position_bias")
+
+ position_bias = self.compute_bias(seq_len, seq_len)
+
+ # if key and values are already calculated
+ # we want only the last query position bias
+ if past_key_value_state is not None:
+ position_bias = position_bias[:, :, -1:, :]
+
+ # TODO: attention_mask should also be the same shape as position_bias.
+ # Sliding attention window??
+ # if attention_mask is not None:
+ # position_bias = position_bias + attention_mask # (1, num_heads, seq_len, 2*window+1)
+
+ attn_weights += position_bias
+
+ attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
+ if key_padding_mask is not None:
+ # softmax sometimes inserts NaN if all positions are masked, replace them with 0
+ attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+ v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+ attn = 0
+ if extra_attention_mask is not None:
+ selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch)
+ selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim)
+ selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
+ # use `matmul` because `einsum` crashes sometimes with fp16
+ # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
+ attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2)
+ attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous()
+
+ if self.attention_mode == 'tvm':
+ v = v.float().contiguous()
+ attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False)
+ elif self.attention_mode == "sliding_chunks":
+ attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window)
+ elif self.attention_mode == "sliding_chunks_no_overlap":
+ attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window)
+ else:
+            raise ValueError(f"Unsupported attention mode: {self.attention_mode}")
+
+ attn = attn.type_as(hidden_states)
+ assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
+ attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous()
+
+ # For this case, we'll just recompute the attention for these indices
+ # and overwrite the attn tensor. TODO: remove the redundant computation
+ if extra_attention_mask is not None:
+ selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim)
+ selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[extra_attention_mask_nonzeros[::-1]]
+
+ q = self.query_global(selected_hidden_states)
+ k = self.key_global(hidden_states)
+ v = self.value_global(hidden_states)
+ q /= math.sqrt(self.head_dim)
+
+ q = q.contiguous().view(max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1) # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim)
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim)
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim)
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ assert list(attn_weights.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len]
+
+ attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
+ attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0
+ if key_padding_mask is not None:
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ -10000.0,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len)
+ attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+ selected_attn = torch.bmm(attn_probs, v)
+ assert list(selected_attn.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim]
+
+ selected_attn_4d = selected_attn.view(bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim)
+ nonzero_selected_attn = selected_attn_4d[selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]]
+ attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(len(selection_padding_mask_nonzeros[0]), -1).type_as(hidden_states)
+
+ context_layer = attn.transpose(0, 1)
+ if output_attentions:
+ if extra_attention_mask is not None:
+ # With global attention, return global attention probabilities only
+ # batch_size x num_heads x max_num_global_attention_tokens x sequence_length
+ # which is the attention weights from tokens with global attention to all tokens
+                # It does not return local attention
+                # In case of a variable number of global attention tokens in the rows of a batch,
+ # attn_weights are padded with -10000.0 attention scores
+ attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len)
+ else:
+ # without global attention, return local attention probabilities
+ # batch_size x num_heads x sequence_length x window_size
+ # which is the attention weights of every token attending to its neighbours
+ attn_weights = attn_weights.permute(0, 2, 1, 3)
+ outputs = (context_layer, attn_weights) if output_attentions else (context_layer,)
+ return outputs
+
+
+class LongformerSelfAttentionForT5(nn.Module):
+ def __init__(self, config, layer_id):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.longformer_self_attn = LongformerSelfAttentionT5Basic(config, layer_id=layer_id,
+ has_relative_attention_bias=True) #config.has_relative_attention_bias)
+ self.output = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def forward(
+ self,
+ query,
+ mask=None,
+ kv=None,
+ position_bias=None,
+ past_key_value_state=None,
+ head_mask=None,
+ query_length=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+ outputs = self.longformer_self_attn(
+ query, #.transpose(0, 1), # LongformerSelfAttention expects (bsz, seqlen, embd_dim)
+ #attention_mask=key_padding_mask.unsqueeze(dim=1).unsqueeze(dim=1) * -1,
+ attention_mask=mask, #.unsqueeze(dim=1).unsqueeze(dim=1)*-1,
+ output_attentions=output_attentions,
+ )
+
+ attn_output = self.output(outputs[0].transpose(0, 1))
+
+ return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)
diff --git a/longformer/sliding_chunks.py b/longformer/sliding_chunks.py
index d39fe9b..8ee30a1 100644
--- a/longformer/sliding_chunks.py
+++ b/longformer/sliding_chunks.py
@@ -125,9 +125,52 @@ def pad_to_window_size(input_ids: torch.Tensor, attention_mask: torch.Tensor,
Returns
(input_ids, attention_mask) padded to length divisible by 2 * one_sided_window_size
'''
- w = 2 * one_sided_window_size
+ w = int(2 * one_sided_window_size)
seqlen = input_ids.size(1)
padding_len = (w - seqlen % w) % w
input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens
return input_ids, attention_mask
+
+
+# ========= "sliding_chunks_no_overlap": an alternative implementation of the sliding window attention =========
+# This implementation uses non-overlapping chunks (or blocks) of size `w`, giving each token 3w local attention locations.
+# To make this implementation comparable to "sliding_chunks", set w such that
+# w_of_sliding_chunks_no_overlap = w_of_sliding_chunks * 2 / 3
+# For example,
+# w_of_sliding_chunks = 256 (this is one sided. Total attention size = 512)
+# w_of_sliding_chunks_no_overlap = 170 (Total attention size = 510)
+# Performance:
+# - Speed: 30% faster than "sliding_chunks"
+# - Memory: 95% of the memory usage of "sliding_chunks"
+# The windows are asymmetric: the number of attention locations on each side of a token ranges between w and 2w,
+# while "sliding_chunks" has a symmetric window around each token.
+# This implementation is roughly similar to the one described in the BigBird paper https://arxiv.org/abs/2007.14062
+
+def sliding_chunks_no_overlap_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
+ bsz, seqlen, num_heads, head_dim = q.size()
+ assert seqlen % w == 0
+ assert q.size() == k.size()
+ # chunk seqlen into non-overlapping chunks of size w
+ chunk_q = q.view(bsz, seqlen // w, w, num_heads, head_dim)
+ chunk_k = k.view(bsz, seqlen // w, w, num_heads, head_dim)
+ chunk_k_expanded = torch.stack((
+ F.pad(chunk_k[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
+ chunk_k,
+ F.pad(chunk_k[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
+ ), dim=-1)
+ diagonal_attn = torch.einsum('bcxhd,bcyhde->bcxhey', (chunk_q, chunk_k_expanded)) # multiply
+ return diagonal_attn.reshape(bsz, seqlen, num_heads, 3 * w)
+
+
+def sliding_chunks_no_overlap_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
+ bsz, seqlen, num_heads, head_dim = v.size()
+ chunk_prob = prob.view(bsz, seqlen // w, w, num_heads, 3, w)
+ chunk_v = v.view(bsz, seqlen // w, w, num_heads, head_dim)
+ chunk_v_extended = torch.stack((
+ F.pad(chunk_v[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
+ chunk_v,
+ F.pad(chunk_v[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
+ ), dim=-1)
+ context = torch.einsum('bcwhpd,bcdhep->bcwhe', (chunk_prob, chunk_v_extended))
+ return context.reshape(bsz, seqlen, num_heads, head_dim)
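+
+
+# A minimal shape sanity-check sketch (illustrative only, not part of the library). With w=170 the
+# total attention size per token is 3 * 170 = 510, comparable to w=256 under "sliding_chunks":
+#
+#     q = k = torch.randn(1, 510, 12, 64)  # seqlen must be a multiple of w
+#     attn = sliding_chunks_no_overlap_matmul_qk(q, k, w=170, padding_value=0.)
+#     assert attn.shape == (1, 510, 12, 3 * 170)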
diff --git a/longformer_on_beaker.sh b/longformer_on_beaker.sh
new file mode 100755
index 0000000..bedf8d9
--- /dev/null
+++ b/longformer_on_beaker.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+export SCRIPTS=$(beaker dataset create -q .)
+export INPUT_DATASET_ID="ds_drt127wv4aun"
+export RESULT_SAVE_DIR="/runs"
+export RESULT_SAVE_PREFIX="test"
+export ARGS="$@"
+export GPU_COUNT=8
+export CPU_COUNT=32
+export CLUSTER="ai2/on-prem-ai2-server3"
+export RESULT_PATH=$RESULT_SAVE_DIR/$RESULT_SAVE_PREFIX
+
+beaker experiment create -f experiment.yml
diff --git a/requirements.txt b/requirements.txt
index 5b004e7..3eb9122 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,7 @@
-torch>=1.2.0
-transformers>=3.0.2
+transformers @ git+http://github.com/ibeltagy/transformers.git@longformer_encoder_decoder#egg=transformers
+pytorch-lightning @ git+http://github.com/ibeltagy/pytorch-lightning.git@v0.8.5_fixes#egg=pytorch-lightning
+torch==1.6.0
tensorboardX
-pytorch-lightning==0.6.0
test-tube==0.7.5
+nlp==0.3.0
+rouge_score
diff --git a/scripts/cheatsheet.txt b/scripts/cheatsheet.txt
index be4fc3a..c0ab4e5 100644
--- a/scripts/cheatsheet.txt
+++ b/scripts/cheatsheet.txt
@@ -70,3 +70,18 @@ python -m scripts.triviaqa_utils.evaluation_utils \
--prediction_file predictions.json
# Output should be:
{'exact_match': 73.07644188665083, 'f1': 77.78523804802242, 'common': 7993, 'denominator': 7993, 'pred_len': 7993, 'gold_len': 7993}
+
+
+# TPU
+import torch_xla.debug.metrics as met; print(met.metrics_report())
+curl -X POST http://10.125.212.42:8475/requestversion/pytorch-dev20200722
+
+/usr/share/torch-xla-nightly/pytorch/xla/scripts/debug_run.py --outfile debug.tar.gz -- python -u scripts/test_tpu.py
+
+/usr/share/torch-xla-nightly/pytorch/xla/scripts/debug_run.py --outfile debug.tar.gz -- python -u scripts/pretrain.py --input_dir data/ --save_prefix test_xla_2 --gpu_count 0 --tpu_core_count 1 --val_batches 4 --val_every 130 --num_workers 0 --log_rate 1 --model allenai/longformer-base-4096
+
+python scripts/pretrain.py --input_dir data/ --save_prefix test_grad_accum --gpu_count 0 --tpu_core_count 8 --val_batches 30 --val_every 30 --num_workers 0 --log_rate 1
+
+export TPU_IP_ADDRESS=10.125.212.42
+export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
+source /anaconda3/bin/activate torch-xla-nightly
diff --git a/scripts/convert_bart_to_longformerencoderdecoder.py b/scripts/convert_bart_to_longformerencoderdecoder.py
new file mode 100644
index 0000000..fc94996
--- /dev/null
+++ b/scripts/convert_bart_to_longformerencoderdecoder.py
@@ -0,0 +1,152 @@
+import argparse
+import logging
+import os
+
+from transformers import BartTokenizer
+
+from transformers import BartForConditionalGeneration
+from transformers.modeling_bart import shift_tokens_right
+from longformer.longformer_encoder_decoder import LongformerSelfAttentionForBart, LongformerEncoderDecoderConfig
+from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def create_long_model(
+ save_model_to,
+ base_model,
+ tokenizer_name_or_path,
+ attention_window,
+ max_pos
+):
+ model = BartForConditionalGeneration.from_pretrained(base_model)
+ tokenizer = BartTokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos)
+ config = LongformerEncoderDecoderConfig.from_pretrained(base_model)
+ model.config = config
+
+ # in BART attention_probs_dropout_prob is attention_dropout, but LongformerSelfAttention
+ # expects attention_probs_dropout_prob, so set it here
+ config.attention_probs_dropout_prob = config.attention_dropout
+ config.architectures = ['LongformerEncoderDecoderForConditionalGeneration', ]
+
+ # extend position embeddings
+ tokenizer.model_max_length = max_pos
+ tokenizer.init_kwargs['model_max_length'] = max_pos
+ current_max_pos, embed_size = model.model.encoder.embed_positions.weight.shape
+ assert current_max_pos == config.max_position_embeddings + 2
+
+ config.max_encoder_position_embeddings = max_pos
+ config.max_decoder_position_embeddings = config.max_position_embeddings
+ del config.max_position_embeddings
+ max_pos += 2 # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2
+ assert max_pos >= current_max_pos
+
+ # allocate a larger position embedding matrix for the encoder
+ new_encoder_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)
+ # copy position embeddings over and over to initialize the new position embeddings
+ k = 2
+ step = current_max_pos - 2
+ while k < max_pos - 1:
+ new_encoder_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:]
+ k += step
+ model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed
+
+ # allocate a larger position embedding matrix for the decoder
+ # new_decoder_pos_embed = model.model.decoder.embed_positions.weight.new_empty(max_pos, embed_size)
+ # # copy position embeddings over and over to initialize the new position embeddings
+ # k = 2
+ # step = current_max_pos - 2
+ # while k < max_pos - 1:
+ # new_decoder_pos_embed[k:(k + step)] = model.model.decoder.embed_positions.weight[2:]
+ # k += step
+ # model.model.decoder.embed_positions.weight.data = new_decoder_pos_embed
+
+ # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention`
+ config.attention_window = [attention_window] * config.num_hidden_layers
+ config.attention_dilation = [1] * config.num_hidden_layers
+
+ for i, layer in enumerate(model.model.encoder.layers):
+ longformer_self_attn_for_bart = LongformerSelfAttentionForBart(config, layer_id=i)
+
+ longformer_self_attn_for_bart.longformer_self_attn.query = layer.self_attn.q_proj
+ longformer_self_attn_for_bart.longformer_self_attn.key = layer.self_attn.k_proj
+ longformer_self_attn_for_bart.longformer_self_attn.value = layer.self_attn.v_proj
+
+ longformer_self_attn_for_bart.longformer_self_attn.query_global = layer.self_attn.q_proj
+ longformer_self_attn_for_bart.longformer_self_attn.key_global = layer.self_attn.k_proj
+ longformer_self_attn_for_bart.longformer_self_attn.value_global = layer.self_attn.v_proj
+
+ longformer_self_attn_for_bart.output = layer.self_attn.out_proj
+
+ layer.self_attn = longformer_self_attn_for_bart
+ logger.info(f'saving model to {save_model_to}')
+ model.save_pretrained(save_model_to)
+ tokenizer.save_pretrained(save_model_to)
+ return model, tokenizer
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Convert BART to LongBART. Replaces BART encoder's SelfAttention with LongformerSelfAttention")
+ parser.add_argument(
+ '--base_model',
+ type=str,
+ default='facebook/bart-large',
+ help='The name or path of the base model you want to convert'
+ )
+ parser.add_argument(
+ '--tokenizer_name_or_path',
+ type=str,
+ default='facebook/bart-large',
+ help='The name or path of the tokenizer'
+ )
+ parser.add_argument(
+ '--save_model_to',
+ type=str,
+ required=True,
+ help='The path to save the converted model'
+ )
+ parser.add_argument(
+ '--attention_window',
+ type=int,
+ default=512,
+ help='attention window size for longformer self attention (one sided)'
+ )
+ parser.add_argument(
+ '--max_pos',
+ type=int,
+ default=4096 * 4,
+ help='maximum encoder positions'
+ )
+
+ args = parser.parse_args()
+
+ if not os.path.exists(args.save_model_to):
+ os.mkdir(args.save_model_to)
+
+ create_long_model(
+ save_model_to=args.save_model_to,
+ base_model=args.base_model,
+ tokenizer_name_or_path=args.tokenizer_name_or_path,
+ attention_window=args.attention_window,
+ max_pos=args.max_pos
+ )
+
+ tokenizer = BartTokenizer.from_pretrained(args.save_model_to)
+    TXT = "My friends are <mask> but they eat too many carbs."
+ model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(args.save_model_to)
+ model.model.encoder.config.gradient_checkpointing = True
+ model.model.decoder.config.gradient_checkpointing = True
+ data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048)
+ input_ids = data['input_ids']
+ attention_mask = data['attention_mask']
+ decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id)
+ logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0]
+ masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+ probs = logits[0, masked_index].softmax(dim=0)
+ values, predictions = probs.topk(5)
+ print(tokenizer.convert_ids_to_tokens(predictions))
+
+
+if __name__ == "__main__":
+ main()
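+
+
+# Example invocation (the output path is a placeholder):
+#
+#     python scripts/convert_bart_to_longformerencoderdecoder.py \
+#         --base_model facebook/bart-large \
+#         --save_model_to /path/to/longformer-encdec-large-16384 \
+#         --attention_window 512 \
+#         --max_pos 16384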
diff --git a/scripts/convert_t5_to_longformerencoderdecoder.py b/scripts/convert_t5_to_longformerencoderdecoder.py
new file mode 100644
index 0000000..c06930b
--- /dev/null
+++ b/scripts/convert_t5_to_longformerencoderdecoder.py
@@ -0,0 +1,157 @@
+import argparse
+import logging
+import os
+
+from transformers import T5Tokenizer
+
+from transformers import T5ForConditionalGeneration
+from transformers.modeling_bart import shift_tokens_right
+from longformer.longformer_t5_encoder_decoder import LongformerSelfAttentionForT5, LongformerEncoderDecoderConfigT5
+from longformer.longformer_t5_encoder_decoder import LongformerEncoderDecoderForConditionalGenerationT5
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def create_long_model(
+ save_model_to,
+ base_model,
+ tokenizer_name_or_path,
+ attention_window,
+ max_pos
+):
+ model = T5ForConditionalGeneration.from_pretrained(base_model)
+ tokenizer = T5Tokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos)
+ config = LongformerEncoderDecoderConfigT5.from_pretrained(base_model)
+ model.config = config
+
+ # in T5 attention_probs_dropout_prob is dropout_rate, but LongformerSelfAttention
+ # expects attention_probs_dropout_prob, so set it here
+ config.attention_probs_dropout_prob = config.dropout_rate
+ config.architectures = ['LongformerEncoderDecoderForConditionalGenerationT5', ]
+
+ # extend position embeddings
+ tokenizer.model_max_length = max_pos
+ tokenizer.init_kwargs['model_max_length'] = max_pos
+ # current_max_pos, embed_size = model.model.embed_positions.weight.shape
+ # assert current_max_pos == config.max_position_embeddings + 2
+
+ # config.max_encoder_position_embeddings = max_pos
+ # config.max_decoder_position_embeddings = config.max_position_embeddings
+ # del config.max_position_embeddings
+ # # TODO: check what's the deal with T5 here.
+ # max_pos += 2 # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2
+ # assert max_pos >= current_max_pos
+
+ # # allocate a larger position embedding matrix for the encoder
+ # new_encoder_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)
+ # # copy position embeddings over and over to initialize the new position embeddings
+ # k = 2
+ # step = current_max_pos - 2
+ # while k < max_pos - 1:
+ # new_encoder_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:]
+ # k += step
+ # model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed
+
+ # allocate a larger position embedding matrix for the decoder
+ # new_decoder_pos_embed = model.model.decoder.embed_positions.weight.new_empty(max_pos, embed_size)
+ # # copy position embeddings over and over to initialize the new position embeddings
+ # k = 2
+ # step = current_max_pos - 2
+ # while k < max_pos - 1:
+ # new_decoder_pos_embed[k:(k + step)] = model.model.decoder.embed_positions.weight[2:]
+ # k += step
+ # model.model.decoder.embed_positions.weight.data = new_decoder_pos_embed
+
+ # replace the `modeling_t5.T5Attention` object with `LongformerSelfAttention`
+ config.attention_window = [attention_window] * config.num_hidden_layers
+ config.attention_dilation = [1] * config.num_hidden_layers
+ # model.encoder.block = model.encoder.block[:1]
+
+ for i, layer in enumerate(model.encoder.block):
+ self_attn = layer.layer[0].SelfAttention
+
+ longformer_self_attn_for_t5 = LongformerSelfAttentionForT5(config, layer_id=i)
+
+ longformer_self_attn_for_t5.longformer_self_attn.query = self_attn.q
+ longformer_self_attn_for_t5.longformer_self_attn.key = self_attn.k
+ longformer_self_attn_for_t5.longformer_self_attn.value = self_attn.v
+
+ longformer_self_attn_for_t5.longformer_self_attn.query_global = self_attn.q
+ longformer_self_attn_for_t5.longformer_self_attn.key_global = self_attn.k
+ longformer_self_attn_for_t5.longformer_self_attn.value_global = self_attn.v
+
+ longformer_self_attn_for_t5.output = self_attn.o
+
+ layer.layer[0].SelfAttention = longformer_self_attn_for_t5
+
+ logger.info(f'saving model to {save_model_to}')
+ model.save_pretrained(save_model_to)
+ tokenizer.save_pretrained(save_model_to)
+ return model, tokenizer
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Convert T5 to LongT5. Replaces T5 encoder's T5Attention with LongformerSelfAttention")
+ parser.add_argument(
+ '--base_model',
+ type=str,
+ default='t5-large',
+ help='The name or path of the base model you want to convert'
+ )
+ parser.add_argument(
+ '--tokenizer_name_or_path',
+ type=str,
+ default='t5-large',
+ help='The name or path of the tokenizer'
+ )
+ parser.add_argument(
+ '--save_model_to',
+ type=str,
+ required=True,
+ help='The path to save the converted model'
+ )
+ parser.add_argument(
+ '--attention_window',
+ type=int,
+ default=512,
+ help='attention window size for longformer self attention (one sided)'
+ )
+ parser.add_argument(
+ '--max_pos',
+ type=int,
+ default=4096 * 4,
+ help='maximum encoder positions'
+ )
+
+ args = parser.parse_args()
+
+ if not os.path.exists(args.save_model_to):
+ os.mkdir(args.save_model_to)
+
+ create_long_model(
+ save_model_to=args.save_model_to,
+ base_model=args.base_model,
+ tokenizer_name_or_path=args.tokenizer_name_or_path,
+ attention_window=args.attention_window,
+ max_pos=args.max_pos
+ )
+
+ tokenizer = T5Tokenizer.from_pretrained(args.save_model_to)
+    TXT = "My friends are <mask> but they eat too many carbs."
+ model = LongformerEncoderDecoderForConditionalGenerationT5.from_pretrained(args.save_model_to)
+ model.encoder.config.gradient_checkpointing = True
+ model.decoder.config.gradient_checkpointing = True
+ data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048)
+ input_ids = data['input_ids']
+ attention_mask = data['attention_mask']
+ decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id)
+ logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0]
+ masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+ probs = logits[0, masked_index].softmax(dim=0)
+ values, predictions = probs.topk(5)
+ print(tokenizer.convert_ids_to_tokens(predictions))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/mem_profiler.py b/scripts/mem_profiler.py
new file mode 100644
index 0000000..5d8e2f7
--- /dev/null
+++ b/scripts/mem_profiler.py
@@ -0,0 +1,69 @@
+from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration
+from longformer.longformer_encoder_decoder import LongformerEncoderDecoderConfig
+
+from longformer.longformer import LongformerForMaskedLM
+from longformer.longformer import LongformerConfig
+
+import torch
+from torch.utils.data import DataLoader, Dataset
+from pytorch_lightning import Trainer
+import pytorch_lightning as pl
+
+seqlen = 1024 * 2
+global_size = seqlen // 100
+attention_window = 256 # one sided
+
+
+class CoolDataset(Dataset):
+ def __len__(self):
+ return 1024 # number of examples
+
+ def __getitem__(self, idx):
+        token_ids = torch.tensor([5] * seqlen)
+        mask = torch.tensor([1] * seqlen)
+        mask[:global_size] = 2
+        return token_ids, mask
+
+
+class MemoryProfiler(pl.LightningModule):
+
+ def __init__(self, hparams=None):
+ super().__init__()
+ self.hparams = hparams
+
+ config = LongformerEncoderDecoderConfig.from_pretrained('bart-long-4096')
+ # config = LongformerConfig.from_pretrained('roberta-large')
+ config.max_position_embeddings = seqlen + 2
+ config.gradient_checkpointing = True
+ config.attention_mode = 'sliding_chunks'
+ # config.attention_mode = 'n2'
+ config.attention_window = [attention_window] * config.num_hidden_layers
+ config.attention_dilation = [1] * config.num_hidden_layers
+ self.model = LongformerEncoderDecoderForConditionalGeneration(config)
+ # self.model = LongformerForMaskedLM(config)
+
+ def forward(self, x, y):
+ print(seqlen, global_size, attention_window, torch.cuda.max_memory_allocated(x.device) / 1024 ** 3)
+ # import ipdb; ipdb.set_trace()
+ # return self.model(x, attention_mask=y, decoder_input_ids=x[:, :attention_window * 2], use_cache=False)
+ return self.model(x, attention_mask=y)
+
+ def training_step(self, batch, batch_idx):
+ # import ipdb; ipdb.set_trace()
+ x, y = batch
+ y_hat = self(x, y)
+ loss = y_hat[0].sum()
+ # import ipdb; ipdb.set_trace()
+ return {'loss': loss}
+
+ def configure_optimizers(self):
+ return torch.optim.Adam(self.parameters(), lr=0.001)
+
+ def train_dataloader(self):
+ return DataLoader(CoolDataset(), batch_size=2, num_workers=0)
+
+
+if __name__ == '__main__':
+ model = MemoryProfiler(hparams={})
+ trainer = Trainer(gpus=[0], progress_bar_refresh_rate=1, max_epochs=1, amp_level='O2', use_amp=True)
+ trainer.fit(model)
diff --git a/scripts/pretrain.py b/scripts/pretrain.py
new file mode 100644
index 0000000..8de5bbd
--- /dev/null
+++ b/scripts/pretrain.py
@@ -0,0 +1,461 @@
+import argparse
+import glob
+import os
+import random
+import logging
+import numpy as np
+import math
+from tqdm import tqdm
+import time
+import torch
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+from transformers import DataCollatorForLanguageModeling
+from transformers.optimization import AdamW, get_linear_schedule_with_warmup
+
+from torch.utils.data import Dataset, DataLoader
+import pytorch_lightning as ptl
+from pytorch_lightning.logging.test_tube import TestTubeLogger
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateLogger
+
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# DONE: reproduce RoBERTa numbers on the Longformer corpus
+# DONE: testing ddp single machine
+# DONE: testing ddp multiple machines
+# DONE: testing resume from checkpoint
+# TODO: try on a TPU-pod
+# TODO: run on beaker on ai2-server1/2
+
+
+try:
+ import torch_xla.core.xla_model as xm
+except ImportError:
+ XLA_AVAILABLE = False
+else:
+ XLA_AVAILABLE = True
+
+
+class MMapTextDataset(Dataset):
+ def __init__(self, mmap_filename, chunk_size, bos_token_id, eos_token_id):
+        # `chunk_size - 2` to reserve space for the BOS and EOS tokens
+ self.num_instances = np.memmap(mmap_filename, mode='r', dtype=np.uint16).shape[0] // (chunk_size - 2)
+ # defer loading the token_ids memmap until after the first __getitem__ call.
+ # when spawning new processes for ddp, there is a hard limit in python < 3.8 that
+ # pickle files need to be < 4GB. By waiting until after the first __getitem__ we
+ # don't have to pickle the memmap
+ self.token_ids = None
+ self._mmap_filename = mmap_filename
+ self._chunk_size = chunk_size
+ self._bos_token_id = bos_token_id
+ self._eos_token_id = eos_token_id
+
+ def __len__(self):
+ return self.num_instances
+
+ def __getitem__(self, i):
+ if self.token_ids is None:
+ self.token_ids = np.memmap(self._mmap_filename, mode='r', dtype=np.uint16)
+ from_index = i * (self._chunk_size - 2)
+ to_index = (i + 1) * (self._chunk_size - 2)
+ data = np.concatenate(([self._bos_token_id], self.token_ids[from_index:to_index], [self._eos_token_id]))
+ return torch.tensor(data, dtype=torch.long)
+
+ # ========================= preprocessing code ========================= #
+ @staticmethod
+ def _process_file(full_fname):
+ "Step 1: tokenize an input text file then save token ids into `np.memmap` shards of size `args.shard_size`"
+ fname = full_fname.split('/')[-1]
+ log_filename = f'{args.input_dir}/logs-{args.shard_size}/{fname}.log'
+ if os.path.isfile(log_filename):
+ logging.info(f'Skipping {full_fname} ...')
+ return # log file already exists. Skip current file.
+
+ logging.info(f'Processing {full_fname} ...')
+ with open(full_fname, 'r') as fin:
+ token_list = []
+ shard_count = 0
+ tokens_count = 0
+
+ def _write_shard():
+ if len(token_list) == 0:
+ return
+ if token_list[-1] != MMapTextDataset.tokenizer.sep_token_id: # handle a rare case
+ token_list.append(MMapTextDataset.tokenizer.sep_token_id)
+ shared_filename = f'{args.input_dir}/shards-{args.shard_size}/{fname}-{shard_count}.bin'
+                logging.info(f'Writing {len(token_list)} tokens to shard {shared_filename}')
+ fp = np.memmap(shared_filename, dtype=np.uint16, mode='w+', shape=len(token_list))
+ fp[:] = token_list[:]
+ del fp # flush and close file
+ for line in tqdm(fin):
+ line = line.strip()
+ if line == '': # drop empty lines
+ continue
+ tokens = MMapTextDataset.tokenizer.encode(line, add_special_tokens=False) # `__getitem__` adds special tokens
+ token_list.extend(tokens)
+ if len(token_list) > args.shard_size:
+ _write_shard()
+ tokens_count += len(token_list)
+ token_list = []
+ shard_count += 1
+ else:
+ token_list.append(MMapTextDataset.tokenizer.sep_token_id)
+ _write_shard()
+ tokens_count += len(token_list)
+ with open(log_filename, 'w') as f:
+ f.write(f'Generated {tokens_count} tokens in {shard_count + 1} shards')
+
+ @staticmethod
+ def _combine_shards(output_fname, shards_list):
+ "Step 2: combining memmap shards into one `train.bin` or `val.bin` file"
+ total_size = 0
+ for filename in shards_list:
+ total_size += np.memmap(filename, mode='r', dtype=np.uint16).shape[0]
+ logging.info(f'Writing {total_size} tokens to {output_fname}')
+ all_token_ids = np.empty(total_size, dtype=np.uint16)
+ last_token_index = 0
+ for filename in tqdm(shards_list):
+ shared = np.memmap(filename, mode='r', dtype=np.uint16)
+ all_token_ids[last_token_index:last_token_index+len(shared)] = shared[:]
+ last_token_index += len(shared)
+ fp = np.memmap(output_fname, dtype=np.uint16, mode='w+', shape=total_size)
+ fp[:] = all_token_ids[:]
+ del fp
+
+ @staticmethod
+ def raw_text_to_mmap(args):
+ """This is the main preprocessing function. It processes all the text files in `args.input_dir` and
+ outputs two np.memmap files, one for training and one for validation with ratio `args.train_dev_split`.
+ Processing each input file involves tokenizing it, sharding it into shards of size `args.shard_size`,
+ then writing each shard as an np.memmap file. The stream of tokens in the memmap file represents documents
+ separated with `tokenizer.sep_token`. In `__getitem__`, the `tokenizer.bos_token` and `tokenizer.eos_token`
+ are added. The reason for not adding them at preprocessing time is to allow different sequence lengths
+        later on. Notice that this is the "FULL-SENTENCES" setting in the RoBERTa paper, Table 2.
+ """
+ MMapTextDataset.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=True)
+ assert len(MMapTextDataset.tokenizer) < 65535 # will use uint16 to store token ids
+ all_files = glob.glob(f'{args.input_dir}/*.txt')
+
+ if os.path.exists(f'{args.input_dir}/cache/train.bin') and os.path.exists(f'{args.input_dir}/cache/val.bin'):
+ logger.info("Cache already exists. Remove the cache directory to regenerate")
+ return
+ try:
+ os.mkdir(f'{args.input_dir}/cache/')
+ except FileExistsError:
+ pass
+ try:
+ os.mkdir(f'{args.input_dir}/shards-{args.shard_size}/')
+ except FileExistsError:
+ pass
+ try:
+            os.mkdir(f'{args.input_dir}/logs-{args.shard_size}/')  # log progress to be able to resume
+ except FileExistsError:
+ pass
+
+ # STEP1: tokenizing and saving to shards
+ if args.num_preprocessing_workers > 1:
+ from multiprocessing.pool import Pool
+ with Pool(args.num_preprocessing_workers) as p:
+ list(tqdm(p.imap(MMapTextDataset._process_file, all_files), total=len(all_files)))
+ else:
+ [MMapTextDataset._process_file(f) for f in tqdm(all_files)]
+
+ # STEP2: shuffling shards and combining them into train.bin and val.bin files
+ all_shards = glob.glob(f'{args.input_dir}/shards-{args.shard_size}/*.bin')
+ random.shuffle(all_shards) # shuffling based on shards not individual lines
+ val_shards_count = int(args.train_dev_split * len(all_shards))
+ val_shards = all_shards[:val_shards_count]
+ train_shards = all_shards[val_shards_count:]
+        # TODO: if MMapTextDataset._combine_shards is very slow for large files, it can be skipped, but we need to
+        # update the dataset to read from multiple shards directly
+ MMapTextDataset._combine_shards(f'{args.input_dir}/cache/val.bin', val_shards)
+ MMapTextDataset._combine_shards(f'{args.input_dir}/cache/train.bin', train_shards)
+
+ del MMapTextDataset.tokenizer
+ # ========================= end preprocessing code ========================= #
+
+
+class Pretrainer(ptl.LightningModule):
+
+ def __init__(self, hparams):
+ super().__init__()
+
+ self.args = hparams
+ self.hparams = self.args
+
+        self.model = AutoModelForMaskedLM.from_pretrained(self.args.model)
+        self.config = self.model.config
+        tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer)
+ self.pad_token_id = tokenizer.pad_token_id
+ self.eos_token_id = tokenizer.eos_token_id
+ self.bos_token_id = tokenizer.bos_token_id
+
+ logger.info(f'Creating dataset cache from dir {self.args.input_dir}. This could be slow the first time.')
+        MMapTextDataset.raw_text_to_mmap(self.args)
+
+ # TODO: add support for other objective functions (whole word masking, BART objectives)
+ self.data_collator = DataCollatorForLanguageModeling(
+ tokenizer=tokenizer, mlm=True, mlm_probability=self.args.mlm_prob
+ )
+ self.start_time = 0
+
+ def to(self, *args, **kwargs):
+ param_count_before_to = len(list(self.parameters()))
+ super().to(*args, **kwargs)
+ if self.trainer.use_tpu:
+ # need to re-tie the weights after moving to XLA!
+ self.model.tie_weights()
+ if 'roberta' in self.args.model:
+ self.model.lm_head.bias = self.model.lm_head.decoder.bias
+ param_count_after_to = len(list(self.parameters()))
+ assert param_count_before_to == param_count_after_to
+
+ def forward(self, input_ids=None, labels=None):
+ # get the padding mask - 1 for NOT masked, 0 for MASKED/PAD
+ attention_mask = (input_ids != self.pad_token_id).int()
+
+ # output is loss, prediction_scores, hidden_states
+ output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+ return output[0] # loss
+
+ def training_step(self, batch, batch_nb):
+ loss = self(**batch)
+ input_ids = batch['input_ids']
+ tensorboard_logs = {
+ 'input_size': input_ids.numel(),
+ 'mlm_loss': loss,
+ 'mlm_bpc': loss/math.log(2),
+ 'mlm_perplexity': torch.exp(loss),
+ 'token_per_step': input_ids.numel() * self.args.grad_accum * self.trainer.world_size,
+ }
+ if self.start_time != 0:
+ elapsed_time = time.time() - self.start_time
+ tensorboard_logs['second_per_batch'] = elapsed_time
+ self.start_time = time.time()
+ if self.on_gpu:
+ tensorboard_logs['memory'] = torch.cuda.memory_allocated(loss.device) / 1024 ** 3
+
+ return {'loss': loss, 'log': tensorboard_logs}
+
+ def validation_step(self, batch, batch_nb):
+ # TODO: log how long evaluation takes
+ self.start_time = 0 # reset training_step timer
+ loss = self(**batch)
+ tensorboard_logs = {
+ 'val_mlm_loss': loss.detach(),
+ }
+ return {'val_loss': tensorboard_logs["val_mlm_loss"], 'log': tensorboard_logs}
+
+ def validation_epoch_end(self, outputs):
+ avg_loss = torch.stack([x['log']['val_mlm_loss'] for x in outputs if 'val_mlm_loss' in x['log']]).mean()
+ if self.use_ddp:
+ # TODO: PTL is already doing this. Is it still needed here?
+ # https://github.com/PyTorchLightning/pytorch-lightning/blob/0.8.5/pytorch_lightning/metrics/converters.py#L251
+ torch.distributed.all_reduce(avg_loss, op=torch.distributed.ReduceOp.SUM)
+ avg_loss /= torch.distributed.get_world_size()
+ elif self.use_tpu:
+ avg_loss = xm.all_reduce(xm.REDUCE_SUM, avg_loss) / xm.xrt_world_size()
+
+ logs = {'val_mlm_loss': avg_loss}
+ return {'log': logs, 'progress_bar': logs, "val_loss": avg_loss}
+
+ def configure_optimizers(self):
+ no_decay = ["bias", "LayerNorm.weight"]
+
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
+ "weight_decay": 0.0,
+ },
+ ]
+ optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.lr, eps=self.args.adam_epsilon)
+ scheduler = get_linear_schedule_with_warmup(
+ optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.args.train_steps
+ )
+ return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
+
+ def _get_loader(self, fname, is_train):
+ dataset = MMapTextDataset(fname, chunk_size=self.args.seqlen,
+ bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id)
+
+ # TODO: consider `replace_sampler_ddp=True` and removing the following if statement
+ if self.trainer.use_ddp:
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train)
+ shuffle = False
+ elif self.trainer.use_tpu:
+ sampler = torch.utils.data.distributed.DistributedSampler(
+ dataset,
+ num_replicas=xm.xrt_world_size(),
+ rank=xm.get_ordinal(),
+ shuffle=is_train,
+ )
+ shuffle = False
+ else:
+ sampler = None
+ shuffle = is_train
+
+ loader = DataLoader(
+ dataset,
+ batch_size=self.args.batch_size,
+ shuffle=shuffle,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ collate_fn=self.data_collator,
+ drop_last=is_train,
+ )
+ return loader
+
+ def train_dataloader(self):
+ return self._get_loader(f'{self.args.input_dir}/cache/train.bin', True)
+
+ def val_dataloader(self):
+ return self._get_loader(f'{self.args.input_dir}/cache/val.bin', False)
+
+ def grad_norm(self, norm_type):
+        # Override PTL `grad_norm` function to only return `total_grad_norm` instead of the norms of individual params
+ # TODO: grad_norm reporting needs to take fp16 loss scale into account
+        parameters = [p for p in self.parameters() if p.grad is not None]
+        device = parameters[0].device if parameters else None
+        total_norm = torch.zeros([], device=device)
+ norm_type = float(norm_type)
+ for p in parameters:
+ param_norm = p.grad.data.pow(norm_type).sum()
+ total_norm.add_(param_norm)
+ total_norm = (total_norm ** (1.0 / norm_type))
+ return {'total_grad_norm': total_norm}
+
+ @staticmethod
+ def add_args(parser):
+ parser.add_argument("--seed", type=int, default=3)
+
+ # Dataset. Some of these params are only useful when generating the dataset cache
+ parser.add_argument("--input_dir", type=str, default='/net/nfs.corp/s2-research/beltagy/longformer/data/')
+ # Used only at the preprocessing phase
+ parser.add_argument("--train_dev_split", type=float, default=0.05)
+        parser.add_argument("--shard_size", type=int, default=1024 ** 3 // 4)  # 256M tokens per shard (~512MB on disk as uint16)
+ parser.add_argument("--num_preprocessing_workers", type=int, default=1)
+ # Used only at the training phase
+ parser.add_argument("--seqlen", type=int, default=512)
+ parser.add_argument("--mlm_prob", type=float, default=0.15)
+
+ # HF model loading
+ parser.add_argument("--tokenizer", type=str, default='roberta-base')
+ parser.add_argument("--model", type=str, default='roberta-base')
+
+ # Checkpointing and logging
+ parser.add_argument("--save_dir", type=str, default='/runs/')
+ parser.add_argument("--save_prefix", type=str, default='test',
+ help="path of output directory is --save_dir/--save_prefix")
+ parser.add_argument("--resume", type=str, default=None, # It is better to use a different output dir.
+ help="Path to a checkpoint to load model weights and training state. It overwrites args")
+ parser.add_argument("--resume_model_only", type=str, default=None,
+ help="Path to a checkpoint to load model weights but not training state")
+ parser.add_argument("--log_rate", type=int, default=10)
+        parser.add_argument("--disable_checkpointing", action='store_true')
+
+ # Training hyperparams
+ parser.add_argument("--lr", type=float, default=1e-5)
+ parser.add_argument("--train_steps", type=int, default=3000, help='# training grad. updates')
+ parser.add_argument("--warmup_steps", type=int, default=1000, help='# warmup grad. updates')
+ parser.add_argument("--val_every", type=int, default=1000, help='# training grad. updates between evaluations')
+ parser.add_argument("--val_batches", type=int, default=1000, help='# evaluation **batches**')
+ parser.add_argument("--weight_decay", type=float, default=0.01)
+ parser.add_argument("--adam_epsilon", type=float, default=1e-6)
+ parser.add_argument("--grad_clip", type=float, default=0) # TODO: test this with fp16. Likely not working
+
+ # RoBERTa's tokens_per_step = 2^18 = 512(seqlen) x 1(gpu_count) x 32(batch_size) x 16(grad_accum)
+ parser.add_argument("--batch_size", type=int, default=32)
+ parser.add_argument("--grad_accum", type=int, default=1)
+
+ # Compute resources
+        parser.add_argument("--fp16", action='store_true')
+ parser.add_argument("--num_workers", type=int, default=0)
+ parser.add_argument("--gpu_count", type=int, default=1, # `--gpus` is reserved for internal use by PTL
+ help="Number of gpus. This respects `CUDA_VISIBLE_DEVICES`")
+
+ # For multi-node training, use the PyTorch launch script. The script and instructions can be found here:
+ # https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py.
+ # To run PTL in a mode compatible with the launch script, two things are needed:
+ # - pass the argument `--use_env` to `torch.distributed.launch`
+ # - make sure `--nproc_per_node` matches `--gpu_count` and `--nnodes` matches `--node_count`.
+        # For example, to run on 2 nodes, 3 gpus each, the command line on node rank 1 would look like:
+        # >>>> python -m torch.distributed.launch  \
+        #               --use_env  --nnodes 2  --nproc_per_node 3  \
+        #               --node_rank 1  --master_addr s2-server4  --master_port 12343  \
+        #               scripts/pretrain.py  \
+        #               --gpu_count 3  --node_count 2  \
+        #               --input_dir my_data_dir --save_prefix test_multinode
+ parser.add_argument("--node_count", type=int, default=1,
+ help="Number of nodes. It needs to match --nnodes of torch.distributed.launch")
+ parser.add_argument("--tpu_core_count", type=int, default=None)
+
+ return parser
+
+
+def main(args):
+ random.seed(args.seed * 10)
+ np.random.seed(args.seed * 100)
+ torch.manual_seed(args.seed * 1000)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(args.seed * 10000)
+
+ if args.resume_model_only is not None:
+ pretrainer = Pretrainer.load_from_checkpoint(args.resume_model_only, args)
+ else:
+ pretrainer = Pretrainer(args)
+
+    # logger here is a SummaryWriter for tensorboard
+ # it is used by the trainer, and certain return variables
+ # from the model are automatically logged
+ logger = TestTubeLogger(
+ save_dir=args.save_dir,
+ name=args.save_prefix,
+ version=0 # always use version=0
+ )
+
+ checkpoint_callback = ModelCheckpoint(
+ # model saved to filepath/prefix_....
+ filepath=os.path.join(args.save_dir, args.save_prefix, 'checkpoint'),
+ prefix='',
+ save_top_k=1,
+ save_last=True,
+ verbose=True,
+ monitor='val_loss',
+ mode='min',
+ period=-1, # to allow multiple checkpoints per epoch
+ )
+
+    args.val_every *= args.grad_accum  # PTL expects the validation interval in batches per gpu, not in gradient updates
+ trainer = ptl.Trainer(
+ gpus=args.gpu_count,
+ num_nodes=args.node_count,
+ num_tpu_cores=args.tpu_core_count,
+ distributed_backend='ddp' if (args.gpu_count > 1 or args.node_count > 1) else None,
+ replace_sampler_ddp=False,
+ track_grad_norm=2,
+ max_epochs=10000, min_epochs=0, max_steps=args.train_steps, # run for many epochs, but stop after max_steps
+ val_check_interval=args.val_every, limit_val_batches=args.val_batches,
+ early_stop_callback=None,
+ row_log_interval=args.log_rate,
+ progress_bar_refresh_rate=args.log_rate,
+ logger=logger,
+ checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else None,
+ accumulate_grad_batches=args.grad_accum,
+ resume_from_checkpoint=args.resume,
+ gradient_clip_val=args.grad_clip,
+ precision=16 if args.fp16 else 32, amp_level='O2',
+ num_sanity_val_steps=2,
+ callbacks=[LearningRateLogger()],
+ )
+ trainer.fit(pretrainer)
+
+
+if __name__ == "__main__":
+ parser = Pretrainer.add_args(argparse.ArgumentParser(description="pretrain"))
+ args = parser.parse_args()
+ main(args)
diff --git a/scripts/summarization.py b/scripts/summarization.py
new file mode 100644
index 0000000..8374022
--- /dev/null
+++ b/scripts/summarization.py
@@ -0,0 +1,348 @@
+import os
+import argparse
+import random
+import numpy as np
+
+import torch
+from torch.utils.data import DataLoader, Dataset
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
+from transformers.optimization import get_linear_schedule_with_warmup, Adafactor
+import nlp
+from rouge_score import rouge_scorer
+
+import pytorch_lightning as pl
+from pytorch_lightning.logging import TestTubeLogger
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
+
+
+from longformer import LongformerEncoderDecoderForConditionalGeneration, LongformerEncoderDecoderConfig
+from longformer.sliding_chunks import pad_to_window_size
+
+
+def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
+ """From fairseq"""
+ if target.dim() == lprobs.dim() - 1:
+ target = target.unsqueeze(-1)
+ nll_loss = -lprobs.gather(dim=-1, index=target)
+ smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
+ if ignore_index is not None:
+ pad_mask = target.eq(ignore_index)
+ nll_loss.masked_fill_(pad_mask, 0.0)
+ smooth_loss.masked_fill_(pad_mask, 0.0)
+ count = (~pad_mask).sum()
+ else:
+ nll_loss = nll_loss.squeeze(-1)
+ smooth_loss = smooth_loss.squeeze(-1)
+ count = nll_loss.numel()
+
+ nll_loss = nll_loss.sum() / count
+ smooth_loss = smooth_loss.sum() / count
+ eps_i = epsilon / lprobs.size(-1)
+ loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
+ return loss, nll_loss
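+# For reference (not used by the code): with V = vocab size, the loss above is the standard
+# label-smoothed cross entropy, averaged over non-padding target tokens:
+#     loss = (1 - epsilon) * nll_loss + (epsilon / V) * sum_over_vocab(-log p)
+# which matches the fairseq implementation this function was taken from.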
+
+
+class SummarizationDataset(Dataset):
+ def __init__(self, hf_dataset, tokenizer, max_input_len, max_output_len):
+ self.hf_dataset = hf_dataset
+ self.tokenizer = tokenizer
+ self.max_input_len = max_input_len
+ self.max_output_len = max_output_len
+
+ def __len__(self):
+ return len(self.hf_dataset)
+
+ def __getitem__(self, idx):
+ entry = self.hf_dataset[idx]
+ input_ids = self.tokenizer.encode(entry['article'], truncation=True, max_length=self.max_input_len)
+ output_ids = self.tokenizer.encode(entry['abstract'], truncation=True, max_length=self.max_output_len)
+ if self.tokenizer.bos_token_id is None: # pegasus
+ output_ids = [self.tokenizer.pad_token_id] + output_ids
+ return torch.tensor(input_ids), torch.tensor(output_ids)
+
+ @staticmethod
+ def collate_fn(batch):
+        # A hack to detect whether this is bart or pegasus. DDP doesn't like global variables or class-level member variables
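+        # (Based on the standard HF tokenizers: BART ends sequences with eos_token_id=2 and pads with 1,
+        #  while Pegasus ends with eos_token_id=1 and pads with 0, so the last id disambiguates the two.)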
+ if batch[0][0][-1].item() == 2:
+ pad_token_id = 1 # AutoTokenizer.from_pretrained('facebook/bart-base').pad_token_id
+ elif batch[0][0][-1].item() == 1:
+ pad_token_id = 0 # AutoTokenizer.from_pretrained('google/pegasus-large').pad_token_id
+ else:
+ assert False
+
+ input_ids, output_ids = list(zip(*batch))
+ input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
+ output_ids = torch.nn.utils.rnn.pad_sequence(output_ids, batch_first=True, padding_value=pad_token_id)
+ return input_ids, output_ids
+
+
+class Summarizer(pl.LightningModule):
+
+ def __init__(self, args):
+ super().__init__()
+ self.args = args
+ self.hparams = args
+ self.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer, use_fast=True)
+
+ if 'long' in self.args.model_path:
+ config = LongformerEncoderDecoderConfig.from_pretrained(self.args.model_path)
+ config.attention_dropout = self.args.attention_dropout
+ config.gradient_checkpointing = self.args.grad_ckpt
+ config.attention_mode = self.args.attention_mode
+ config.attention_window = [self.args.attention_window] * config.encoder_layers
+ self.model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(
+ self.args.model_path, config=config)
+ else:
+ config = AutoConfig.from_pretrained(self.args.model_path)
+ config.attention_dropout = self.args.attention_dropout
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(
+ self.args.model_path, config=config)
+ self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None
+
+ def _prepare_input(self, input_ids):
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
+ attention_mask[input_ids == self.tokenizer.pad_token_id] = 0
+ if isinstance(self.model, LongformerEncoderDecoderForConditionalGeneration):
+            # global attention on one token so that all model params are used,
+            # which is important for gradient checkpointing to work
+            attention_mask[:, 0] = 2
+ if self.args.attention_mode == 'sliding_chunks':
+ half_padding_mod = self.model.config.attention_window[0]
+ elif self.args.attention_mode == 'sliding_chunks_no_overlap':
+ half_padding_mod = self.model.config.attention_window[0] / 2
+ else:
+ raise NotImplementedError
+ input_ids, attention_mask = pad_to_window_size( # ideally, should be moved inside the LongformerModel
+ input_ids, attention_mask, half_padding_mod, self.tokenizer.pad_token_id)
+ return input_ids, attention_mask
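+    # Attention-mask convention used above (Longformer): 0 = padding, 1 = local sliding-window attention,
+    # 2 = global attention. The value 2 is only set in the LongformerEncoderDecoder branch; plain seq2seq
+    # models receive the usual 0/1 mask.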
+
+ def forward(self, input_ids, output_ids):
+ input_ids, attention_mask = self._prepare_input(input_ids)
+ decoder_input_ids = output_ids[:, :-1]
+ decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id)
+ labels = output_ids[:, 1:].clone()
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ use_cache=False,)
+ lm_logits = outputs[0]
+ if self.args.label_smoothing == 0:
+            # Same behavior as modeling_bart.py, except that pad_token_id is additionally ignored
+ ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
+ assert lm_logits.shape[-1] == self.model.config.vocab_size
+ loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
+ else:
+ lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
+ loss, nll_loss = label_smoothed_nll_loss(
+ lprobs, labels, self.args.label_smoothing, ignore_index=self.tokenizer.pad_token_id
+ )
+ return [loss]
+
+ def training_step(self, batch, batch_nb):
+ output = self.forward(*batch)
+ loss = output[0]
+ lr = loss.new_zeros(1) + self.trainer.optimizers[0].param_groups[0]['lr']
+ tensorboard_logs = {'train_loss': loss, 'lr': lr,
+ 'input_size': batch[0].numel(),
+ 'output_size': batch[1].numel(),
+ 'mem': torch.cuda.memory_allocated(loss.device) / 1024 ** 3 if torch.cuda.is_available() else 0}
+ return {'loss': loss, 'log': tensorboard_logs}
+
+ def validation_step(self, batch, batch_nb):
+ for p in self.model.parameters():
+ p.requires_grad = False
+
+ outputs = self.forward(*batch)
+ vloss = outputs[0]
+ input_ids, output_ids = batch
+ input_ids, attention_mask = self._prepare_input(input_ids)
+ generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
+ use_cache=True, max_length=self.args.max_output_len,
+ num_beams=1)
+ generated_str = self.tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True)
+ gold_str = self.tokenizer.batch_decode(output_ids.tolist(), skip_special_tokens=True)
+ scorer = rouge_scorer.RougeScorer(rouge_types=['rouge1', 'rouge2', 'rougeL', 'rougeLsum'], use_stemmer=False)
+ rouge1 = rouge2 = rougel = rougelsum = 0.0
+ for ref, pred in zip(gold_str, generated_str):
+ score = scorer.score(ref, pred)
+ rouge1 += score['rouge1'].fmeasure
+ rouge2 += score['rouge2'].fmeasure
+ rougel += score['rougeL'].fmeasure
+ rougelsum += score['rougeLsum'].fmeasure
+ rouge1 /= len(generated_str)
+ rouge2 /= len(generated_str)
+ rougel /= len(generated_str)
+ rougelsum /= len(generated_str)
+
+ return {'vloss': vloss,
+ 'rouge1': vloss.new_zeros(1) + rouge1,
+ 'rouge2': vloss.new_zeros(1) + rouge2,
+ 'rougeL': vloss.new_zeros(1) + rougel,
+ 'rougeLsum': vloss.new_zeros(1) + rougelsum, }
+
+ def validation_epoch_end(self, outputs):
+ for p in self.model.parameters():
+ p.requires_grad = True
+
+ names = ['vloss', 'rouge1', 'rouge2', 'rougeL', 'rougeLsum']
+ metrics = []
+ for name in names:
+ metric = torch.stack([x[name] for x in outputs]).mean()
+ if self.trainer.use_ddp:
+ torch.distributed.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
+ metric /= self.trainer.world_size
+ metrics.append(metric)
+        logs = dict(zip(names, metrics))
+ print(logs)
+ return {'avg_val_loss': logs['vloss'], 'log': logs, 'progress_bar': logs}
+
+ def test_step(self, batch, batch_nb):
+ return self.validation_step(batch, batch_nb)
+
+ def test_epoch_end(self, outputs):
+ result = self.validation_epoch_end(outputs)
+ print(result)
+
+ def configure_optimizers(self):
+ if self.args.adafactor:
+ optimizer = Adafactor(self.model.parameters(), lr=self.args.lr, scale_parameter=False, relative_step=False)
+ else:
+ optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
+ if self.args.debug:
+ return optimizer # const LR
+ num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
+ num_steps = self.args.dataset_size * self.args.epochs / num_gpus / self.args.grad_accum / self.args.batch_size
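+        # Illustrative arithmetic with the defaults (dataset_size=203037, epochs=5, batch_size=16,
+        # grad_accum=1, single gpu): 203037 * 5 / 1 / 1 / 16 ≈ 63,449 scheduler steps.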
+ scheduler = get_linear_schedule_with_warmup(
+ optimizer, num_warmup_steps=self.args.warmup, num_training_steps=num_steps
+ )
+ return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
+
+ def _get_dataloader(self, current_dataloader, split_name, is_train):
+ if current_dataloader is not None:
+ return current_dataloader
+ dataset = SummarizationDataset(hf_dataset=self.hf_datasets[split_name], tokenizer=self.tokenizer,
+ max_input_len=self.args.max_input_len, max_output_len=self.args.max_output_len)
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train) if self.trainer.use_ddp else None
+ return DataLoader(dataset, batch_size=self.args.batch_size, shuffle=(sampler is None),
+ num_workers=self.args.num_workers, sampler=sampler,
+ collate_fn=SummarizationDataset.collate_fn)
+
+ @pl.data_loader
+ def train_dataloader(self):
+ self.train_dataloader_object = self._get_dataloader(self.train_dataloader_object, 'train', is_train=True)
+ return self.train_dataloader_object
+
+ @pl.data_loader
+ def val_dataloader(self):
+ self.val_dataloader_object = self._get_dataloader(self.val_dataloader_object, 'validation', is_train=False)
+ return self.val_dataloader_object
+
+ @pl.data_loader
+ def test_dataloader(self):
+ self.test_dataloader_object = self._get_dataloader(self.test_dataloader_object, 'test', is_train=False)
+ return self.test_dataloader_object
+
+ def configure_ddp(self, model, device_ids):
+ model = LightningDistributedDataParallel(
+ model,
+ device_ids=device_ids,
+ find_unused_parameters=False
+ )
+ return model
+
+ @staticmethod
+ def add_model_specific_args(parser, root_dir):
+ parser.add_argument("--save_dir", type=str, default='summarization')
+ parser.add_argument("--save_prefix", type=str, default='test')
+ parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
+ parser.add_argument("--grad_accum", type=int, default=1, help="number of gradient accumulation steps")
+        parser.add_argument("--gpus", type=int, default=-1,
+                            help="Number of gpus. -1 for all available gpus, 0 for CPU")
+ parser.add_argument("--warmup", type=int, default=1000, help="Number of warmup steps")
+ parser.add_argument("--lr", type=float, default=0.00003, help="Maximum learning rate")
+        parser.add_argument("--val_every", type=float, default=1.0, help="How often to run validation, as a fraction of the training epoch")
+ parser.add_argument("--val_percent_check", default=1.00, type=float, help='Percent of validation data used')
+ parser.add_argument("--num_workers", type=int, default=0, help="Number of data loader workers")
+ parser.add_argument("--seed", type=int, default=1234, help="Seed")
+ parser.add_argument("--epochs", type=int, default=5, help="Number of epochs")
+ parser.add_argument("--disable_checkpointing", action='store_true', help="No logging or checkpointing")
+ parser.add_argument("--max_output_len", type=int, default=256,
+ help="maximum num of wordpieces/summary. Used for training and testing")
+        parser.add_argument("--max_input_len", type=int, default=512,
+                            help="maximum num of wordpieces in the input document. Used for training and testing")
+ parser.add_argument("--test", action='store_true', help="Test only, no training")
+ parser.add_argument("--model_path", type=str, default='facebook/bart-base',
+ help="Path to the checkpoint directory or model name")
+ parser.add_argument("--tokenizer", type=str, default='facebook/bart-base')
+        parser.add_argument("--no_progress_bar", action='store_true', help="Disable the progress bar. Useful when output goes to a log file")
+ parser.add_argument("--fp32", action='store_true', help="default is fp16. Use --fp32 to switch to fp32")
+ parser.add_argument("--debug", action='store_true', help="debug run")
+ parser.add_argument("--resume_ckpt", type=str, help="Path of a checkpoint to resume from")
+ parser.add_argument('--grad_ckpt', action='store_true', help='Enable gradient checkpointing to save memory')
+ parser.add_argument("--attention_dropout", type=float, default=0.1, help="attention dropout")
+ parser.add_argument("--attention_mode", type=str, default='sliding_chunks', help="Longformer attention mode")
+ parser.add_argument("--attention_window", type=int, default=512, help="Attention window")
+ parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
+ parser.add_argument("--adafactor", action='store_true', help="Use adafactor optimizer")
+
+ return parser
+
+
+def main(args):
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(args.seed)
+
+ model = Summarizer(args)
+ model.hf_datasets = nlp.load_dataset('scientific_papers', 'arxiv')
+
+ logger = TestTubeLogger(
+ save_dir=args.save_dir,
+ name=args.save_prefix,
+ version=0 # always use version=0
+ )
+
+ checkpoint_callback = ModelCheckpoint(
+ filepath=os.path.join(args.save_dir, args.save_prefix, "checkpoints"),
+ save_top_k=5,
+ verbose=True,
+ monitor='avg_val_loss',
+ mode='min',
+ period=-1,
+ prefix=''
+ )
+
+ print(args)
+
+ args.dataset_size = 203037 # hardcode dataset size. Needed to compute number of steps for the lr scheduler
+
+ trainer = pl.Trainer(gpus=args.gpus, distributed_backend='ddp' if torch.cuda.is_available() else None,
+ track_grad_norm=-1,
+ max_epochs=args.epochs if not args.debug else 100,
+ max_steps=None if not args.debug else 1,
+ replace_sampler_ddp=False,
+ accumulate_grad_batches=args.grad_accum,
+ val_check_interval=args.val_every if not args.debug else 1,
+ num_sanity_val_steps=2 if not args.debug else 0,
+ check_val_every_n_epoch=1 if not args.debug else 1,
+ val_percent_check=args.val_percent_check,
+ test_percent_check=args.val_percent_check,
+ logger=logger,
+ checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else False,
+ show_progress_bar=not args.no_progress_bar,
+ use_amp=not args.fp32, amp_level='O2',
+ resume_from_checkpoint=args.resume_ckpt,
+ )
+ if not args.test:
+ trainer.fit(model)
+ trainer.test(model)
+
+
+if __name__ == "__main__":
+ main_arg_parser = argparse.ArgumentParser(description="summarization")
+ parser = Summarizer.add_model_specific_args(main_arg_parser, os.getcwd())
+ args = parser.parse_args()
+ main(args)
diff --git a/scripts/test_tpu.py b/scripts/test_tpu.py
new file mode 100644
index 0000000..e692890
--- /dev/null
+++ b/scripts/test_tpu.py
@@ -0,0 +1,44 @@
+import torch
+from torch.utils.data import DataLoader, Dataset
+from transformers import AutoModel
+import pytorch_lightning as pl
+
+
+class CoolDataset(Dataset):
+
+ def __len__(self):
+ return 128 * 128
+
+ def __getitem__(self, idx):
+ return torch.tensor([1, 2, 3, 4] * 128 * 8), torch.tensor([1, 1, 1, 1] * 128 * 8)
+
+
+class CoolSystem(pl.LightningModule):
+
+ def __init__(self):
+ super().__init__()
+
+ self.model = AutoModel.from_pretrained('allenai/longformer-base-4096')
+ # self.model = AutoModel.from_pretrained('roberta-base')
+
+ def forward(self, x, y):
+ return self.model(x, attention_mask=None)
+
+ def training_step(self, batch, batch_idx):
+ x, y = batch
+ y_hat = self(x, y)
+ loss = y_hat[0].sum()
+ return {'loss': loss}
+
+ def configure_optimizers(self):
+ return torch.optim.Adam(self.parameters(), lr=0.001)
+
+ def train_dataloader(self):
+ loader = DataLoader(CoolDataset(), batch_size=1, num_workers=0)
+ return loader
+
+
+if __name__ == '__main__':
+ model = CoolSystem()
+ trainer = pl.Trainer(num_tpu_cores=8, progress_bar_refresh_rate=1, max_epochs=10, num_sanity_val_steps=0, gpus=0)
+ trainer.fit(model)
diff --git a/scripts/triviaqa.py b/scripts/triviaqa.py
index 281c297..967f97a 100644
--- a/scripts/triviaqa.py
+++ b/scripts/triviaqa.py
@@ -9,7 +9,7 @@
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset
-from transformers import RobertaTokenizer
+from transformers import RobertaTokenizer, AutoModel, AutoConfig, AutoModelWithLMHead
from scripts.triviaqa_utils import evaluation_utils
import pytorch_lightning as pl
@@ -110,11 +110,13 @@ def is_whitespace(c):
try:
start_position = char_to_word_offset[answer_offset]
end_position = char_to_word_offset[answer_offset + answer_length - 1]
- except:
+ token_ids = self.tokenizer.encode(orig_answer_text)
+ except RuntimeError:
print(f'Reading example {idx} failed')
start_position = 0
end_position = 0
- answer_spans.append({'start': start_position, 'end': end_position})
+ answer_spans.append({'start': start_position, 'end': end_position,
+ 'text': orig_answer_text, 'token_ids': token_ids})
# ===== Given an example, convert it into tensors =============
query_tokens = self.tokenizer.tokenize(question_text)
@@ -146,6 +148,7 @@ def is_whitespace(c):
segment_ids_list = []
start_positions_list = []
end_positions_list = []
+ answer_token_ids_list = []
for slice_start in range(0, len(all_doc_tokens), max_tokens_per_doc_slice - self.doc_stride):
slice_end = min(slice_start + max_tokens_per_doc_slice, len(all_doc_tokens))
@@ -172,6 +175,7 @@ def is_whitespace(c):
doc_offset = len(query_tokens) + 2 - slice_start
start_positions = []
end_positions = []
+ answer_token_ids = []
for answer_span in answer_spans:
start_position = answer_span['start']
end_position = answer_span['end']
@@ -183,6 +187,7 @@ def is_whitespace(c):
continue
start_positions.append(tok_start_position_in_doc + doc_offset)
end_positions.append(tok_end_position_in_doc + doc_offset)
+ answer_token_ids.append(answer_span['token_ids'])
assert len(start_positions) == len(end_positions)
if self.ignore_seq_with_no_answers and len(start_positions) == 0:
continue
@@ -190,32 +195,58 @@ def is_whitespace(c):
# answers from start_positions and end_positions if > self.max_num_answers
start_positions = start_positions[:self.max_num_answers]
end_positions = end_positions[:self.max_num_answers]
+ answer_token_ids = answer_token_ids[:self.max_num_answers]
# -1 padding up to self.max_num_answers
padding_len = self.max_num_answers - len(start_positions)
start_positions.extend([-1] * padding_len)
end_positions.extend([-1] * padding_len)
+ answer_token_ids.extend([[]] * padding_len)
# replace duplicate start/end positions with `-1` because duplicates can result into -ve loss values
found_start_positions = set()
found_end_positions = set()
- for i, (start_position, end_position) in enumerate(zip(start_positions, end_positions)):
+ found_answer_token_ids = set()
+ for i, (start_position, end_position, answer_tokens) in enumerate(
+ zip(start_positions, end_positions, answer_token_ids)
+ ):
if start_position in found_start_positions:
start_positions[i] = -1
if end_position in found_end_positions:
end_positions[i] = -1
+ answer_tokens_as_str = ','.join([str(x) for x in answer_tokens])
+ if answer_tokens_as_str in found_answer_token_ids:
+ answer_token_ids[i] = []
found_start_positions.add(start_position)
found_end_positions.add(end_position)
+ found_answer_token_ids.add(answer_tokens_as_str)
input_ids_list.append(input_ids)
input_mask_list.append(input_mask)
segment_ids_list.append(segment_ids)
start_positions_list.append(start_positions)
end_positions_list.append(end_positions)
+ answer_token_ids_list.append(answer_token_ids)
+
+ # pad answers in answer_token_ids_list to the longest answer
+ max_answer_len = max([len(item) for sublist in answer_token_ids_list for item in sublist]) # flat list
+ if max_answer_len == 0:
+ max_answer_len = 2
+ for answers_of_one_slice in answer_token_ids_list:
+ for answer_tokens in answers_of_one_slice:
+ if len(answer_tokens) == 0:
+ # TODO: or ?
+ padding_len = max_answer_len - len(answer_tokens) - 2
+ answer_tokens.extend([self.tokenizer.bos_token_id, self.tokenizer.eos_token_id] +
+ ([self.tokenizer.pad_token_id] * padding_len))
+ else:
+ padding_len = max_answer_len - len(answer_tokens)
+ answer_tokens.extend([self.tokenizer.pad_token_id] * padding_len)
tensors_list.append((torch.tensor(input_ids_list), torch.tensor(input_mask_list),
torch.tensor(segment_ids_list),
torch.tensor(start_positions_list), torch.tensor(end_positions_list),
+ torch.tensor(answer_token_ids_list),
self._get_qid(qa['id']), qa["aliases"])) # for eval
return tensors_list
@@ -259,14 +290,39 @@ def __init__(self, args):
self.tokenizer.model_max_length = self.args.max_seq_len
self.model = self.load_model()
self.num_labels = 2
- self.qa_outputs = torch.nn.Linear(self.model.config.hidden_size, self.num_labels)
+ if not self.args.seq2seq:
+ self.qa_outputs = torch.nn.Linear(self.model.config.hidden_size, self.num_labels)
self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None
def load_model(self):
- model = Longformer.from_pretrained(self.args.model_path)
- for layer in model.encoder.layer:
- layer.attention.self.attention_mode = self.args.attention_mode
- self.args.attention_window = layer.attention.self.attention_window
+ if 'longformer' in self.args.model_path:
+ model = Longformer.from_pretrained(self.args.model_path)
+ for layer in model.encoder.layer:
+ layer.attention.self.attention_mode = self.args.attention_mode
+ self.args.attention_window = layer.attention.self.attention_window
+ elif self.args.model_path in ['bart.large', 'bart.base']:
+ model = torch.hub.load('pytorch/fairseq', self.args.model_path)
+ model.config = model.args
+ model.config.hidden_size = model.config.decoder_output_dim
+ elif 'bart' in self.args.model_path and 'base' in self.args.model_path:
+ config = AutoConfig.from_pretrained(self.args.model_path)
+ config.encoder_attention_heads = 12
+ config.decoder_attention_heads = 12
+ config.attention_dropout = 0.1
+ if self.args.seq2seq:
+ model = AutoModelWithLMHead.from_pretrained(self.args.model_path, config=config)
+ else:
+ model = AutoModel.from_pretrained(self.args.model_path, config=config)
+ elif 'bart' in self.args.model_path and 'large' in self.args.model_path:
+ config = AutoConfig.from_pretrained(self.args.model_path)
+ config.attention_dropout = 0.1
+ config.gradient_checkpointing = True
+ if self.args.seq2seq:
+ model = AutoModelWithLMHead.from_pretrained(self.args.model_path, config=config)
+ else:
+ model = AutoModel.from_pretrained(self.args.model_path, config=config)
+ else:
+ model = AutoModel.from_pretrained(self.args.model_path)
print("Loaded model with config:")
print(model.config)
@@ -276,30 +332,51 @@ def load_model(self):
model.train()
return model
- def forward(self, input_ids, attention_mask, segment_ids, start_positions, end_positions):
- question_end_index = self._get_question_end_index(input_ids)
- # Each batch is one document, and each row of the batch is a chunck of the document.
- # Make sure all rows have the same question length.
- assert (question_end_index[0].float() == question_end_index.float().mean()).item()
-
- # local attention everywhere
- attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
- # global attention for the question tokens
- attention_mask[:, :question_end_index.item()] = 2
-
- # sliding_chunks implemenation of selfattention requires that seqlen is multiple of window size
- input_ids, attention_mask = pad_to_window_size(
- input_ids, attention_mask, self.args.attention_window, self.tokenizer.pad_token_id)
-
- sequence_output = self.model(
- input_ids,
- attention_mask=attention_mask)[0]
-
- # The pretrained TriviaQA model wasn't trained with padding, so remove padding tokens
- # before computing loss and decoding.
- padding_len = input_ids[0].eq(self.tokenizer.pad_token_id).sum()
- if padding_len > 0:
- sequence_output = sequence_output[:, :-padding_len]
+ def forward(self, input_ids, attention_mask, segment_ids, start_positions, end_positions, answer_token_ids):
+ if 'longformer' in self.args.model_path:
+ question_end_index = self._get_question_end_index(input_ids)
+            # Each batch is one document, and each row of the batch is a chunk of the document.
+ # Make sure all rows have the same question length.
+ assert (question_end_index[0].float() == question_end_index.float().mean()).item()
+
+ # local attention everywhere
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
+ # global attention for the question tokens
+ attention_mask[:, :question_end_index.item()] = 2
+
+            # sliding_chunks implementation of self-attention requires that seqlen is a multiple of the window size
+ input_ids, attention_mask = pad_to_window_size(
+ input_ids, attention_mask, self.args.attention_window, self.tokenizer.pad_token_id)
+
+ sequence_output = self.model(
+ input_ids,
+ attention_mask=attention_mask)[0]
+
+ # The pretrained TriviaQA model wasn't trained with padding, so remove padding tokens
+ # before computing loss and decoding.
+ padding_len = input_ids[0].eq(self.tokenizer.pad_token_id).sum()
+ if padding_len > 0:
+ sequence_output = sequence_output[:, :-padding_len]
+ elif self.args.model_path in ['bart.large', 'bart.base']:
+ sequence_output = self.model.extract_features(input_ids)
+ else:
+ if self.args.seq2seq:
+ decoder_input_ids = answer_token_ids[:, 0, :-1].clone()
+ decoder_input_ids[decoder_input_ids == self.tokenizer.eos_token_id] = self.tokenizer.pad_token_id
+ decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id)
+ labels = answer_token_ids[:, 0, 1:].contiguous()
+ labels[answer_token_ids[:, 0, 1:] == self.tokenizer.pad_token_id] = -100
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ labels=labels)
+ loss = outputs[0]
+ logit_scores = outputs[1].softmax(dim=2)[:, :, 0].sum(dim=1)
+ return [loss, logit_scores]
+ else:
+ sequence_output = self.model(input_ids, attention_mask=attention_mask)[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
@@ -368,8 +445,8 @@ def or_softmax_cross_entropy_loss_one_doc(self, logits, target, ignore_index=-1,
return loss[~torch.isinf(loss)].sum()
def training_step(self, batch, batch_nb):
- input_ids, input_mask, segment_ids, subword_starts, subword_ends, qids, aliases = batch
- output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends)
+ input_ids, input_mask, segment_ids, subword_starts, subword_ends, answer_token_ids, qids, aliases = batch
+ output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends, answer_token_ids)
loss = output[0]
lr = loss.new_zeros(1) + self.trainer.optimizers[0].param_groups[0]['lr']
tensorboard_logs = {'train_loss': loss, 'lr': lr,
@@ -378,8 +455,29 @@ def training_step(self, batch, batch_nb):
return {'loss': loss, 'log': tensorboard_logs}
def validation_step(self, batch, batch_nb):
- input_ids, input_mask, segment_ids, subword_starts, subword_ends, qids, aliases = batch
- output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends)
+ input_ids, input_mask, segment_ids, subword_starts, subword_ends, answer_token_ids, qids, aliases = batch
+ output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends, answer_token_ids)
+ if self.args.seq2seq:
+ logit_scores = output[1]
+ answer_score_indices = logit_scores.sort().indices
+ generated_ids = self.model.generate(input_ids=input_ids, attention_mask=input_mask, use_cache=True,)
+ answer_text = ''
+ best_answer_score = 0
+ for i in answer_score_indices:
+ generated_answer_ids = generated_ids[answer_score_indices[i]]
+ generated_answer_ids[-1] = self.tokenizer.eos_token_id
+ index_of_eos_token = (generated_answer_ids == self.tokenizer.eos_token_id).nonzero()[0, 0].item()
+ generated_answer_ids = generated_answer_ids[1:index_of_eos_token]
+ answer_text = self.tokenizer.decode(generated_answer_ids)
+ if answer_text != '':
+ best_answer_score = logit_scores[answer_score_indices[i]]
+ break
+ f1_score = evaluation_utils.metric_max_over_ground_truths(evaluation_utils.f1_score, answer_text, aliases)
+ em_score = evaluation_utils.metric_max_over_ground_truths(evaluation_utils.exact_match_score, answer_text, aliases)
+ return {'vloss': output[0], 'vem': generated_answer_ids.new_zeros([1]).float(),
+ 'qids': [qids], 'answer_scores': [best_answer_score],
+ 'f1': [f1_score], 'em': [em_score]}
+
loss, start_logits, end_logits = output[:3]
answers = self.decode(input_ids, start_logits, end_logits)
@@ -453,8 +551,8 @@ def decode(self, input_ids, start_logits, end_logits):
answers.append({'text': text, 'score': score})
return answers
- def sync_list_across_gpus(self, l, device, dtype):
- l_tensor = torch.tensor(l, device=device, dtype=dtype)
+ def sync_list_across_gpus(self, list_to_sync, device, dtype):
+ l_tensor = torch.tensor(list_to_sync, device=device, dtype=dtype)
gather_l_tensor = [torch.ones_like(l_tensor) for _ in range(self.trainer.world_size)]
torch.distributed.all_gather(gather_l_tensor, l_tensor)
return torch.cat(gather_l_tensor).tolist()
@@ -499,8 +597,11 @@ def validation_end(self, outputs):
return {'avg_val_loss': avg_loss, 'log': logs, 'progress_bar': logs}
def test_step(self, batch, batch_nb):
- input_ids, input_mask, segment_ids, subword_starts, subword_ends, qids, aliases = batch
- output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends)
+ input_ids, input_mask, segment_ids, subword_starts, subword_ends, answer_token_ids, qids, aliases = batch
+ output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends, answer_token_ids)
+ if self.args.seq2seq:
+            raise NotImplementedError
+
loss, start_logits, end_logits = output[:3]
answers = self.decode(input_ids, start_logits, end_logits)
@@ -528,21 +629,14 @@ def test_end(self, outputs):
return {'count': len(qid_to_answer_text)}
- def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
- optimizer.step()
- optimizer.zero_grad()
- self.scheduler.step(self.global_step)
-
def configure_optimizers(self):
def lr_lambda(current_step):
if current_step < self.args.warmup:
return float(current_step) / float(max(1, self.args.warmup))
return max(0.0, float(self.args.steps - current_step) / float(max(1, self.args.steps - self.args.warmup)))
optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr)
- self.scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1) # scheduler is not saved in the checkpoint, but global_step is, which is enough to restart
- self.scheduler.step(self.global_step)
-
- return optimizer
+ scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1)
+ return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
@pl.data_loader
def train_dataloader(self):
@@ -554,7 +648,7 @@ def train_dataloader(self):
max_num_answers=self.args.max_num_answers,
max_question_len=self.args.max_question_len,
ignore_seq_with_no_answers=self.args.ignore_seq_with_no_answers)
- sampler = torch.utils.data.distributed.DistributedSampler(dataset) if self.trainer.use_ddp else None
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True) if self.trainer.use_ddp else None
dl = DataLoader(dataset, batch_size=1, shuffle=(sampler is None),
num_workers=self.args.num_workers, sampler=sampler,
collate_fn=TriviaQADataset.collate_one_doc_and_lists)
@@ -571,8 +665,8 @@ def val_dataloader(self):
max_num_answers=self.args.max_num_answers,
max_question_len=self.args.max_question_len,
ignore_seq_with_no_answers=False) # evaluation data should keep all examples
- sampler = torch.utils.data.distributed.DistributedSampler(dataset) if self.trainer.use_ddp else None
- dl = DataLoader(dataset, batch_size=1, shuffle=(sampler is None),
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False) if self.trainer.use_ddp else None
+ dl = DataLoader(dataset, batch_size=1, shuffle=False,
num_workers=self.args.num_workers, sampler=sampler,
collate_fn=TriviaQADataset.collate_one_doc_and_lists)
self.val_dataloader_object = dl
@@ -599,7 +693,7 @@ def configure_ddp(self, model, device_ids):
model = LightningDistributedDataParallel(
model,
device_ids=device_ids,
- find_unused_parameters=True
+ find_unused_parameters=False
)
return model
@@ -610,11 +704,11 @@ def add_model_specific_args(parser, root_dir):
parser.add_argument("--train_dataset", type=str, required=False, help="Path to the training squad-format")
parser.add_argument("--dev_dataset", type=str, required=True, help="Path to the dev squad-format")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
- parser.add_argument("--gpus", type=str, default='0',
- help="Comma separated list of gpus. Default is gpu 0. To use CPU, use --gpus "" ")
+ parser.add_argument("--gpus", type=int, default=1,
+ help="Number of gpus. 0 for CPU")
parser.add_argument("--warmup", type=int, default=200, help="Number of warmup steps")
parser.add_argument("--lr", type=float, default=0.0001, help="Maximum learning rate")
- parser.add_argument("--val_every", type=float, default=0.2, help="Number of training steps between validations")
+ parser.add_argument("--val_every", type=float, default=0.5, help="Number of training steps between validations")
parser.add_argument("--val_percent_check", default=1.00, type=float, help='Percent of validation data used')
parser.add_argument("--num_workers", type=int, default=4, help="Number of data loader workers")
parser.add_argument("--seed", type=int, default=1234, help="Seed")
@@ -636,7 +730,8 @@ def add_model_specific_args(parser, root_dir):
help="Number of answer candidates. Used at decoding time")
parser.add_argument("--max_answer_length", type=int, default=30,
help="maximum num of wordpieces/answer. Used at decoding time")
- parser.add_argument("--regular_softmax_loss", action='store_true', help="IF true, use regular softmax. Default is using ORed softmax loss")
+        parser.add_argument("--regular_softmax_loss", action='store_true',
+                            help="If true, use regular softmax. Default is the ORed softmax loss")
parser.add_argument("--test", action='store_true', help="Test only, no training")
parser.add_argument("--model_path", type=str, required=True,
help="Path to the checkpoint directory")
@@ -644,6 +739,9 @@ def add_model_specific_args(parser, root_dir):
parser.add_argument("--attention_mode", type=str, choices=['tvm', 'sliding_chunks'],
default='sliding_chunks', help='Which implementation of selfattention to use')
parser.add_argument("--fp32", action='store_true', help="default is fp16. Use --fp32 to switch to fp32")
+ parser.add_argument("--seq2seq", action='store_true', help="Use an answer generation model")
+ parser.add_argument("--resume_ckpt", type=str, help="Path of a checkpoint to resume from")
+
return parser
@@ -667,28 +765,32 @@ def main(args):
filepath=os.path.join(args.save_dir, args.save_prefix, "checkpoints"),
save_top_k=5,
verbose=True,
- monitor='avg_val_f1',
- mode='max',
+ monitor='avg_val_loss',
+ # save_last=True,
+ mode='min',
+ period=-1,
prefix=''
)
- args.gpus = [int(x) for x in args.gpus.split(',')] if args.gpus is not "" else None # use CPU if no gpu provided
print(args)
train_set_size = 110648 # hardcode dataset size. Needed to compute number of steps for the lr scheduler
- num_devices = 1 or len(args.gpus)
- args.steps = args.epochs * train_set_size / (args.batch_size * num_devices)
- print(f'>>>>>>> #steps: {args.steps}, #epochs: {args.epochs}, batch_size: {args.batch_size * num_devices} <<<<<<<')
+ args.steps = args.epochs * train_set_size / (args.batch_size * max(args.gpus, 1))
+    print(f'>>>>>>> #steps: {args.steps}, #epochs: {args.epochs}, batch_size: {args.batch_size * max(args.gpus, 1)} <<<<<<<')
- trainer = pl.Trainer(gpus=args.gpus, distributed_backend='ddp' if args.gpus and (len(args.gpus) > 1) else None,
- track_grad_norm=-1, max_nb_epochs=args.epochs, early_stop_callback=None,
+ trainer = pl.Trainer(gpus=args.gpus, distributed_backend='ddp' if args.gpus and args.gpus > 1 else None,
+ track_grad_norm=-1, max_epochs=args.epochs, early_stop_callback=None,
+ replace_sampler_ddp=False,
accumulate_grad_batches=args.batch_size,
val_check_interval=args.val_every,
+ num_sanity_val_steps=2,
+ # check_val_every_n_epoch=2,
val_percent_check=args.val_percent_check,
test_percent_check=args.val_percent_check,
logger=logger if not args.disable_checkpointing else False,
checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else False,
show_progress_bar=not args.no_progress_bar,
use_amp=not args.fp32, amp_level='O2',
+ resume_from_checkpoint=args.resume_ckpt,
)
if not args.test:
trainer.fit(model)