diff --git a/README.md b/README.md index 466500d..aed7e73 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,24 @@ #

`Longformer`

`Longformer` is a BERT-like model for long documents. + +**\*\*\*\*\* Work In Progress: LongformerEncoderDecoder \*\*\*\*\*** + +A `LongformerEncoderDecoder` model is now available. It is geared towards summarization, where the input is long but the output is relatively short. The following code snippet loads a `LongformerEncoderDecoder` checkpoint that was initialized from `BART`. With gradient checkpointing, fp16, and a 48GB GPU, the input length can be up to 16K tokens. +``` +pip install git+https://github.com/allenai/longformer.git@encoderdecoder + +# checkpoint-base: https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/longformer-encdec-base-16384.tar.gz +# checkpoint-large: https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/longformer-encdec-large-16384.tar.gz + +from longformer import LongformerEncoderDecoderForConditionalGeneration +model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(downloaded_checkpoint, gradient_checkpointing=True) +``` + +- Check the script `scripts/summarization.py` for an example of how to use the model (a minimal generation sketch also follows below). + +- Make sure to use the huggingface/transformers fork specified in `requirements.txt`.
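The snippet below is a minimal, illustrative generation sketch rather than one of the repository scripts: `downloaded_checkpoint` and `long_document` are placeholders, the `BART` tokenizer is assumed (the model is initialized from `BART` weights and the tokenizer is saved alongside the converted checkpoint), and `max_length=4096` is an assumed value chosen to be a multiple of twice the one-sided attention window so the sliding-chunks attention can split the input into equal chunks. See `scripts/summarization.py` for the full training/evaluation pipeline.
```
# Illustrative only; see the assumptions in the paragraph above.
import torch
from transformers import BartTokenizer
from longformer import LongformerEncoderDecoderForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained(downloaded_checkpoint)
model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(downloaded_checkpoint)
model.eval()

# pad/truncate so the sequence length is a multiple of 2 * attention_window
data = tokenizer(long_document, return_tensors='pt', padding='max_length', truncation=True, max_length=4096)
with torch.no_grad():
    summary_ids = model.generate(data['input_ids'], attention_mask=data['attention_mask'],
                                 num_beams=4, max_length=256)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```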
+ **\*\*\*\*\* New July 23rd, 2020: Speed degradation \*\*\*\*\*** A significant speed degradation in huggingface/transformers was recently discovered and fixed (check [this PR](https://github.com/huggingface/transformers/pull/5811) for details). To avoid this problem, either use the older [release v2.11.0](https://github.com/huggingface/transformers/tree/v2.11.0) (which doesn't support gradient checkpointing) or use the master branch. This problem should be fixed with the next huggingface/transformers release. diff --git a/experiment.yml b/experiment.yml new file mode 100644 index 0000000..156faf5 --- /dev/null +++ b/experiment.yml @@ -0,0 +1,18 @@ +tasks: + - cluster: {{.Env.CLUSTER}} + spec: + # This is a python3.7/nvidia base image with basic libraries + image: im_j69gti4atcw9 + resultPath: {{.Env.RESULT_PATH}} + args: + - /bin/bash + - -c + - "cd /longformer_on_beaker && pip install . && {{.Env.ARGS}}" + datasetMounts: + - datasetId: {{.Env.INPUT_DATASET_ID}} + containerPath: /data + - datasetId: {{.Env.SCRIPTS}} + containerPath: /longformer_on_beaker + requirements: + gpuCount: {{.Env.GPU_COUNT}} + cpu: {{.Env.CPU_COUNT}} diff --git a/longformer/__init__.py b/longformer/__init__.py index e69de29..d3e343c 100644 --- a/longformer/__init__.py +++ b/longformer/__init__.py @@ -0,0 +1,3 @@ +from longformer.longformer import Longformer, LongformerForMaskedLM, LongformerConfig +from longformer.longformer_encoder_decoder import LongformerEncoderDecoderConfig +from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration \ No newline at end of file diff --git a/longformer/longformer.py b/longformer/longformer.py index 953bd2c..ecf19d1 100644 --- a/longformer/longformer.py +++ b/longformer/longformer.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from longformer.diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations from longformer.sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv +from longformer.sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv from transformers.modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM @@ -48,7 +49,7 @@ def __init__(self, attention_window: List[int] = None, attention_dilation: List[ self.attention_dilation = attention_dilation self.autoregressive = autoregressive self.attention_mode = attention_mode - assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2'] + assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2', 'sliding_chunks_no_overlap'] class LongformerSelfAttention(nn.Module): @@ -58,7 +59,6 @@ def __init__(self, config, layer_id): raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads)) - self.output_attentions = config.output_attentions self.num_heads = config.num_attention_heads self.head_dim = int(config.hidden_size / config.num_attention_heads) self.embed_dim = config.hidden_size @@ -80,8 +80,8 @@ def __init__(self, config, layer_id): self.autoregressive = config.autoregressive assert self.attention_window > 0 assert self.attention_dilation > 0 - assert self.attention_mode in ['tvm', 'sliding_chunks'] - if self.attention_mode == 'sliding_chunks': + assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap'] + if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']: assert not self.autoregressive # not supported assert self.attention_dilation == 1 # dilation is not supported @@ -147,8 +147,12 @@ def forward( q = q.float().contiguous() k = k.float().contiguous() attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False) - else: # "sliding_chunks" + elif self.attention_mode == "sliding_chunks": attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0) + elif self.attention_mode == "sliding_chunks_no_overlap": + attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0) + else: + raise ValueError(f'Unsupported attention mode: {self.attention_mode}') mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False) if remove_from_windowed_attention_mask is not None: # This implementation is fast and takes very little memory because num_heads x hidden_size = 1 @@ -162,10 +166,14 @@ def forward( # diagonal mask with zeros everywhere and -inf in place of padding if
self.attention_mode == 'tvm': d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False) - else: + elif self.attention_mode == "sliding_chunks": d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0) + elif self.attention_mode == "sliding_chunks_no_overlap": + d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0) + attn_weights += d_mask - assert list(attn_weights.size()) == [bsz, seq_len, self.num_heads, self.attention_window * 2 + 1] + assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads] + assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3] # the extra attention if extra_attention_mask is not None: @@ -182,7 +190,6 @@ def forward( if key_padding_mask is not None: # softmax sometimes inserts NaN if all positions are masked, replace them with 0 attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0) - attn_weights = attn_weights_float.type_as(attn_weights) attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) @@ -199,8 +206,12 @@ def forward( if self.attention_mode == 'tvm': v = v.float().contiguous() attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False) - else: # "sliding_chunks" + elif self.attention_mode == "sliding_chunks": attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window) + elif self.attention_mode == "sliding_chunks_no_overlap": + attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window) + else: + raise ValueError(f'Unsupported attention mode: {self.attention_mode}') attn = attn.type_as(hidden_states) assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim] diff --git a/longformer/longformer_encoder_decoder.py b/longformer/longformer_encoder_decoder.py new file mode 100644 index 0000000..df38224 --- /dev/null +++ b/longformer/longformer_encoder_decoder.py @@ -0,0 +1,76 @@ +from typing import List, Optional, Tuple, Dict +from torch import nn, Tensor +from longformer.longformer import LongformerSelfAttention +from transformers.modeling_bart import BartConfig, BartForConditionalGeneration + + +class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration): + def __init__(self, config): + super().__init__(config) + if config.attention_mode == 'n2': + pass # do nothing, use BertSelfAttention instead + else: + for i, layer in enumerate(self.model.encoder.layers): + layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i) + + +class LongformerEncoderDecoderConfig(BartConfig): + def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None, + autoregressive: bool = False, attention_mode: str = 'sliding_chunks', + gradient_checkpointing: bool = False, **kwargs): + """ + Args: + attention_window: list of attention window sizes of length = number of layers. + window size = number of attention locations on each side. + For an effective window size of 512, use `attention_window=[256]*num_layers` + which is 256 on each side. + attention_dilation: list of attention dilation of length = number of layers. + attention dilation of `1` means no dilation.
+ autoregressive: use autoregressive attention or attend to both sides + attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implementation of Longformer + self-attention, 'sliding_chunks' for another implementation of Longformer self-attention + """ + super().__init__(**kwargs) + self.attention_window = attention_window + self.attention_dilation = attention_dilation + self.autoregressive = autoregressive + self.attention_mode = attention_mode + self.gradient_checkpointing = gradient_checkpointing + assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2'] + + +class LongformerSelfAttentionForBart(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + self.embed_dim = config.d_model + self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id) + self.output = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + query, + key: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + layer_state: Optional[Dict[str, Optional[Tensor]]] = None, + attn_mask: Optional[Tensor] = None, + need_weights=False, + output_attentions=False, + ) -> Tuple[Tensor, Optional[Tensor]]: + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + assert attn_mask is None + + outputs = self.longformer_self_attn( + query.transpose(0, 1), # LongformerSelfAttention expects (bsz, seqlen, embd_dim) + attention_mask=key_padding_mask.unsqueeze(dim=1).unsqueeze(dim=1) * -1, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=output_attentions, + ) + + attn_output = self.output(outputs[0].transpose(0, 1)) + + return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None) diff --git a/longformer/longformer_t5_encoder_decoder.py b/longformer/longformer_t5_encoder_decoder.py new file mode 100644 index 0000000..85cf786 --- /dev/null +++ b/longformer/longformer_t5_encoder_decoder.py @@ -0,0 +1,379 @@ +import math +from typing import List, Optional, Tuple, Dict +from torch import nn, Tensor +from longformer.longformer import LongformerSelfAttention +from longformer.sliding_chunks import * +from transformers.modeling_t5 import T5Config, T5ForConditionalGeneration + + +class LongformerEncoderDecoderForConditionalGenerationT5(T5ForConditionalGeneration): + def __init__(self, config): + super().__init__(config) + if config.attention_mode == 'n2': + pass # do nothing, use BertSelfAttention instead + else: + for i, layer in enumerate(self.encoder.block): + layer.layer[0].SelfAttention = LongformerSelfAttentionForT5(config, layer_id=i) + + +class LongformerEncoderDecoderConfigT5(T5Config): + def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None, + autoregressive: bool = False, attention_mode: str = 'sliding_chunks', + gradient_checkpointing: bool = False, **kwargs): + """ + Args: + attention_window: list of attention window sizes of length = number of layers. + window size = number of attention locations on each side. + For an effective window size of 512, use `attention_window=[256]*num_layers` + which is 256 on each side. + attention_dilation: list of attention dilation of length = number of layers. + attention dilation of `1` means no dilation.
+ autoregressive: use autoregressive attention or attend to both sides + attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implementation of Longformer + self-attention, 'sliding_chunks' for another implementation of Longformer self-attention + """ + super().__init__(**kwargs) + self.attention_window = attention_window + self.attention_dilation = attention_dilation + self.autoregressive = autoregressive + self.attention_mode = attention_mode + self.gradient_checkpointing = gradient_checkpointing + self.attention_probs_dropout_prob = self.dropout_rate + assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2'] + +class LongformerSelfAttentionT5Basic(nn.Module): + def __init__(self, config, layer_id, has_relative_attention_bias=False): + super(LongformerSelfAttentionT5Basic, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.num_heads = config.num_attention_heads + self.head_dim = int(config.hidden_size / config.num_attention_heads) + self.embed_dim = config.hidden_size + + self.query = nn.Linear(config.hidden_size, self.embed_dim) + self.key = nn.Linear(config.hidden_size, self.embed_dim) + self.value = nn.Linear(config.hidden_size, self.embed_dim) + + # this is for the T5 setting + self.is_decoder = config.is_decoder + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.has_relative_attention_bias = has_relative_attention_bias + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.num_heads) + + self.query_global = nn.Linear(config.hidden_size, self.embed_dim) + self.key_global = nn.Linear(config.hidden_size, self.embed_dim) + self.value_global = nn.Linear(config.hidden_size, self.embed_dim) + + self.dropout = config.attention_probs_dropout_prob + + self.layer_id = layer_id + self.attention_window = config.attention_window[self.layer_id] + self.attention_dilation = config.attention_dilation[self.layer_id] + self.attention_mode = config.attention_mode + self.autoregressive = config.autoregressive + assert self.attention_window > 0 + assert self.attention_dilation > 0 + assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap'] + if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']: + assert not self.autoregressive # not supported + assert self.attention_dilation == 1 # dilation is not supported + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_postion_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, qlen, klen): + """ Compute binned relative position bias """ + relative_position = torch.tensor([[i-self.attention_window for i in range(2*self.attention_window+1)]]) + rp_bucket = self._relative_position_bucket( + relative_position, # shape (qlen, klen) + bidirectional=not self.is_decoder, + num_buckets=self.relative_attention_num_buckets, + ) + rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads) +# values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen) + # Changing the shape to below because that's what LongformerSelfAttention's attn_weights need. 
+ values = values.permute([0, 2, 1]).unsqueeze(0) # shape (1, qlen, num_heads, klen) + return values + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + past_key_value_state=None, + head_mask=None, + output_attentions=False, + ): + ''' + The `attention_mask` is changed in `BertModel.forward` from 0, 1, 2 to + -ve: no attention + 0: local attention + +ve: global attention + ''' + if attention_mask is not None: + attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) + key_padding_mask = attention_mask < 0 + extra_attention_mask = attention_mask > 0 + remove_from_windowed_attention_mask = attention_mask != 0 + + num_extra_indices_per_batch = extra_attention_mask.long().sum(dim=1) + max_num_extra_indices_per_batch = num_extra_indices_per_batch.max() + if max_num_extra_indices_per_batch <= 0: + extra_attention_mask = None + else: + # To support the case of variable number of global attention in the rows of a batch, + # we use the following three selection masks to select global attention embeddings + # in a 3d tensor and pad it to `max_num_extra_indices_per_batch` + # 1) selecting embeddings that correspond to global attention + extra_attention_mask_nonzeros = extra_attention_mask.nonzero(as_tuple=True) + zero_to_max_range = torch.arange(0, max_num_extra_indices_per_batch, + device=num_extra_indices_per_batch.device) + # mask indicating which values are actually going to be padding + selection_padding_mask = zero_to_max_range < num_extra_indices_per_batch.unsqueeze(dim=-1) + # 2) location of the non-padding values in the selected global attention + selection_padding_mask_nonzeros = selection_padding_mask.nonzero(as_tuple=True) + # 3) location of the padding values in the selected global attention + selection_padding_mask_zeros = (selection_padding_mask == 0).nonzero(as_tuple=True) + else: + remove_from_windowed_attention_mask = None + extra_attention_mask = None + key_padding_mask = None + + hidden_states = hidden_states.transpose(0, 1) + seq_len, bsz, embed_dim = hidden_states.size() + assert embed_dim == self.embed_dim + q = self.query(hidden_states) + k = self.key(hidden_states) + v = self.value(hidden_states) + q /= math.sqrt(self.head_dim) + + q = q.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) + k = k.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) + # attn_weights = (bsz, seq_len, num_heads, window*2+1) + if self.attention_mode == 'tvm': + q = q.float().contiguous() + k = k.float().contiguous() + attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False) + elif self.attention_mode == "sliding_chunks": + attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0) + elif self.attention_mode == "sliding_chunks_no_overlap": + attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0) + else: + raise False + mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False) + if remove_from_windowed_attention_mask is not None: + # This implementation is fast and takes very little memory because num_heads x hidden_size = 1 + # from (bsz x seq_len) to (bsz x seq_len x num_heads x hidden_size) + remove_from_windowed_attention_mask = remove_from_windowed_attention_mask.unsqueeze(dim=-1).unsqueeze(dim=-1) + # cast to float/half then replace 1's with -inf + float_mask = remove_from_windowed_attention_mask.type_as(q).masked_fill(remove_from_windowed_attention_mask, -10000.0) + 
repeat_size = 1 if isinstance(self.attention_dilation, int) else len(self.attention_dilation) + float_mask = float_mask.repeat(1, 1, repeat_size, 1) + ones = float_mask.new_ones(size=float_mask.size()) # tensor of ones + # diagonal mask with zeros everywhere and -inf inplace of padding + if self.attention_mode == 'tvm': + d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False) + elif self.attention_mode == "sliding_chunks": + d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0) + elif self.attention_mode == "sliding_chunks_no_overlap": + d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0) + + attn_weights += d_mask + assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads] + assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3] + + # the extra attention + if extra_attention_mask is not None: + selected_k = k.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) + selected_k[selection_padding_mask_nonzeros] = k[extra_attention_mask_nonzeros] + # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch) + selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, selected_k)) + selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000 + # concat to attn_weights + # (bsz, seq_len, num_heads, extra attention count + 2*window+1) + attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) + + if position_bias is None: + if not self.has_relative_attention_bias: + raise ValueError("No position_bias provided and no weights to compute position_bias") + + position_bias = self.compute_bias(seq_len, seq_len) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value_state is not None: + position_bias = position_bias[:, :, -1:, :] + + # TODO: attention_mask should also be the same shape as position_bias. + # Sliding attention window?? 
+ # if attention_mask is not None: + # position_bias = position_bias + attention_mask # (1, num_heads, seq_len, 2*window+1) + + attn_weights += position_bias + + attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + if key_padding_mask is not None: + # softmax sometimes inserts NaN if all positions are masked, replace them with 0 + attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) + v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) + attn = 0 + if extra_attention_mask is not None: + selected_attn_probs = attn_probs.narrow(-1, 0, max_num_extra_indices_per_batch) + selected_v = v.new_zeros(bsz, max_num_extra_indices_per_batch, self.num_heads, self.head_dim) + selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros] + # use `matmul` because `einsum` crashes sometimes with fp16 + # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v)) + attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2) + attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous() + + if self.attention_mode == 'tvm': + v = v.float().contiguous() + attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False) + elif self.attention_mode == "sliding_chunks": + attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window) + elif self.attention_mode == "sliding_chunks_no_overlap": + attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window) + else: + raise False + + attn = attn.type_as(hidden_states) + assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim] + attn = attn.transpose(0, 1).reshape(seq_len, bsz, embed_dim).contiguous() + + # For this case, we'll just recompute the attention for these indices + # and overwrite the attn tensor. 
TODO: remove the redundant computation + if extra_attention_mask is not None: + selected_hidden_states = hidden_states.new_zeros(max_num_extra_indices_per_batch, bsz, embed_dim) + selected_hidden_states[selection_padding_mask_nonzeros[::-1]] = hidden_states[extra_attention_mask_nonzeros[::-1]] + + q = self.query_global(selected_hidden_states) + k = self.key_global(hidden_states) + v = self.value_global(hidden_states) + q /= math.sqrt(self.head_dim) + + q = q.contiguous().view(max_num_extra_indices_per_batch, bsz * self.num_heads, self.head_dim).transpose(0, 1) # (bsz*self.num_heads, max_num_extra_indices_per_batch, head_dim) + k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim) + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) # bsz * self.num_heads, seq_len, head_dim) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_weights.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len] + + attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len) + attn_weights[selection_padding_mask_zeros[0], :, selection_padding_mask_zeros[1], :] = -10000.0 + if key_padding_mask is not None: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + -10000.0, + ) + attn_weights = attn_weights.view(bsz * self.num_heads, max_num_extra_indices_per_batch, seq_len) + attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability + attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) + selected_attn = torch.bmm(attn_probs, v) + assert list(selected_attn.size()) == [bsz * self.num_heads, max_num_extra_indices_per_batch, self.head_dim] + + selected_attn_4d = selected_attn.view(bsz, self.num_heads, max_num_extra_indices_per_batch, self.head_dim) + nonzero_selected_attn = selected_attn_4d[selection_padding_mask_nonzeros[0], :, selection_padding_mask_nonzeros[1]] + attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(len(selection_padding_mask_nonzeros[0]), -1).type_as(hidden_states) + + context_layer = attn.transpose(0, 1) + if output_attentions: + if extra_attention_mask is not None: + # With global attention, return global attention probabilities only + # batch_size x num_heads x max_num_global_attention_tokens x sequence_length + # which is the attention weights from tokens with global attention to all tokens + # It does not return local attention + # In case of a variable number of global attention tokens in the rows of a batch, + # attn_weights are padded with -10000.0 attention scores + attn_weights = attn_weights.view(bsz, self.num_heads, max_num_extra_indices_per_batch, seq_len) + else: + # without global attention, return local attention probabilities + # batch_size x num_heads x sequence_length x window_size + # which is the attention weights of every token attending to its neighbours + attn_weights = attn_weights.permute(0, 2, 1, 3) + outputs = (context_layer, attn_weights) if output_attentions else (context_layer,) + return outputs + + +class LongformerSelfAttentionForT5(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + self.embed_dim = config.d_model + self.longformer_self_attn = LongformerSelfAttentionT5Basic(config, layer_id=layer_id, + has_relative_attention_bias=True) #config.has_relative_attention_bias) + self.output = nn.Linear(self.embed_dim, self.embed_dim)
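# The conversion script (scripts/convert_t5_to_longformerencoderdecoder.py) later overwrites the query/key/value (and *_global) projections inside `longformer_self_attn`, and this `output` layer, with T5's pretrained q/k/v/o weights.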
+ + def forward( + self, + query, + mask=None, + kv=None, + position_bias=None, + past_key_value_state=None, + head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + + outputs = self.longformer_self_attn( + query, #.transpose(0, 1), # LongformerSelfAttention expects (bsz, seqlen, embd_dim) + #attention_mask=key_padding_mask.unsqueeze(dim=1).unsqueeze(dim=1) * -1, + attention_mask=mask, #.unsqueeze(dim=1).unsqueeze(dim=1)*-1, + output_attentions=output_attentions, + ) + + attn_output = self.output(outputs[0].transpose(0, 1)) + + return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None) diff --git a/longformer/sliding_chunks.py b/longformer/sliding_chunks.py index d39fe9b..8ee30a1 100644 --- a/longformer/sliding_chunks.py +++ b/longformer/sliding_chunks.py @@ -125,9 +125,52 @@ def pad_to_window_size(input_ids: torch.Tensor, attention_mask: torch.Tensor, Returns (input_ids, attention_mask) padded to length divisible by 2 * one_sided_window_size ''' - w = 2 * one_sided_window_size + w = int(2 * one_sided_window_size) seqlen = input_ids.size(1) padding_len = (w - seqlen % w) % w input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id) attention_mask = F.pad(attention_mask, (0, padding_len), value=False) # no attention on the padding tokens return input_ids, attention_mask + + +# ========= "sliding_chunks_no_overlap": alternative implementation of the sliding window attention ========= +# This implementation uses non-overlapping chunks (or blocks) of size `w`, giving each token 3w local attention locations +# To make this implementation comparable to "sliding_chunks" set w such that +# w_of_sliding_chunks_no_overlap = w_of_sliding_chunks * 2 / 3 +# For example, +# w_of_sliding_chunks = 256 (this is one sided. Total attention size = 512) +# w_of_sliding_chunks_no_overlap = 170 (Total attention size = 510) +# Performance: +# - Speed: 30% faster than "sliding_chunks" +# - Memory: 95% of the memory usage of "sliding_chunks" +# The windows are asymmetric: the number of attended locations on each side of a token ranges between w and 2w, +# while "sliding_chunks" has a symmetric window around each token.
+# This implementation is roughly similar to the implementation described in the BigBird paper https://arxiv.org/abs/2007.14062 + +def sliding_chunks_no_overlap_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float): + bsz, seqlen, num_heads, head_dim = q.size() + assert seqlen % w == 0 + assert q.size() == k.size() + # chunk seqlen into non-overlapping chunks of size w + chunk_q = q.view(bsz, seqlen // w, w, num_heads, head_dim) + chunk_k = k.view(bsz, seqlen // w, w, num_heads, head_dim) + chunk_k_expanded = torch.stack(( + F.pad(chunk_k[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0), + chunk_k, + F.pad(chunk_k[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0), + ), dim=-1) + diagonal_attn = torch.einsum('bcxhd,bcyhde->bcxhey', (chunk_q, chunk_k_expanded)) # multiply + return diagonal_attn.reshape(bsz, seqlen, num_heads, 3 * w) + + +def sliding_chunks_no_overlap_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int): + bsz, seqlen, num_heads, head_dim = v.size() + chunk_prob = prob.view(bsz, seqlen // w, w, num_heads, 3, w) + chunk_v = v.view(bsz, seqlen // w, w, num_heads, head_dim) + chunk_v_extended = torch.stack(( + F.pad(chunk_v[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0), + chunk_v, + F.pad(chunk_v[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0), + ), dim=-1) + context = torch.einsum('bcwhpd,bcdhep->bcwhe', (chunk_prob, chunk_v_extended)) + return context.reshape(bsz, seqlen, num_heads, head_dim) diff --git a/longformer_on_beaker.sh b/longformer_on_beaker.sh new file mode 100755 index 0000000..bedf8d9 --- /dev/null +++ b/longformer_on_beaker.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +export SCRIPTS=$(beaker dataset create -q .) +export INPUT_DATASET_ID="ds_drt127wv4aun" +export RESULT_SAVE_DIR="/runs" +export RESULT_SAVE_PREFIX="test" +export ARGS="$@" +export GPU_COUNT=8 +export CPU_COUNT=32 +export CLUSTER="ai2/on-prem-ai2-server3" +export RESULT_PATH=$RESULT_SAVE_DIR/$RESULT_SAVE_PREFIX + +beaker experiment create -f experiment.yml diff --git a/requirements.txt b/requirements.txt index 5b004e7..3eb9122 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ -torch>=1.2.0 -transformers>=3.0.2 +transformers @ git+http://github.com/ibeltagy/transformers.git@longformer_encoder_decoder#egg=transformers +pytorch-lightning @ git+http://github.com/ibeltagy/pytorch-lightning.git@v0.8.5_fixes#egg=pytorch-lightning +torch==1.6.0 tensorboardX -pytorch-lightning==0.6.0 test-tube==0.7.5 +nlp==0.3.0 +rouge_score diff --git a/scripts/cheatsheet.txt b/scripts/cheatsheet.txt index be4fc3a..c0ab4e5 100644 --- a/scripts/cheatsheet.txt +++ b/scripts/cheatsheet.txt @@ -70,3 +70,18 @@ python -m scripts.triviaqa_utils.evaluation_utils \ --prediction_file predictions.json # Output should be: {'exact_match': 73.07644188665083, 'f1': 77.78523804802242, 'common': 7993, 'denominator': 7993, 'pred_len': 7993, 'gold_len': 7993} + + +# TPU +import torch_xla.debug.metrics as met; print(met.metrics_report()) +curl -X POST http://10.125.212.42:8475/requestversion/pytorch-dev20200722 + +/usr/share/torch-xla-nightly/pytorch/xla/scripts/debug_run.py --outfile debug.tar.gz -- python -u scripts/test_tpu.py + +/usr/share/torch-xla-nightly/pytorch/xla/scripts/debug_run.py --outfile debug.tar.gz -- python -u scripts/pretrain.py --input_dir data/ --save_prefix test_xla_2 --gpu_count 0 --tpu_core_count 1 --val_batches 4 --val_every 130 --num_workers 0 --log_rate 1 --model allenai/longformer-base-4096 + +python scripts/pretrain.py --input_dir data/ --save_prefix test_grad_accum --gpu_count 0 
--tpu_core_count 8 --val_batches 30 --val_every 30 --num_workers 0 --log_rate 1 + +export TPU_IP_ADDRESS=10.125.212.42 +export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470" +source /anaconda3/bin/activate torch-xla-nightly diff --git a/scripts/convert_bart_to_longformerencoderdecoder.py b/scripts/convert_bart_to_longformerencoderdecoder.py new file mode 100644 index 0000000..fc94996 --- /dev/null +++ b/scripts/convert_bart_to_longformerencoderdecoder.py @@ -0,0 +1,152 @@ +import argparse +import logging +import os + +from transformers import BartTokenizer + +from transformers import BartForConditionalGeneration +from transformers.modeling_bart import shift_tokens_right +from longformer.longformer_encoder_decoder import LongformerSelfAttentionForBart, LongformerEncoderDecoderConfig +from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def create_long_model( + save_model_to, + base_model, + tokenizer_name_or_path, + attention_window, + max_pos +): + model = BartForConditionalGeneration.from_pretrained(base_model) + tokenizer = BartTokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos) + config = LongformerEncoderDecoderConfig.from_pretrained(base_model) + model.config = config + + # in BART attention_probs_dropout_prob is attention_dropout, but LongformerSelfAttention + # expects attention_probs_dropout_prob, so set it here + config.attention_probs_dropout_prob = config.attention_dropout + config.architectures = ['LongformerEncoderDecoderForConditionalGeneration', ] + + # extend position embeddings + tokenizer.model_max_length = max_pos + tokenizer.init_kwargs['model_max_length'] = max_pos + current_max_pos, embed_size = model.model.encoder.embed_positions.weight.shape + assert current_max_pos == config.max_position_embeddings + 2 + + config.max_encoder_position_embeddings = max_pos + config.max_decoder_position_embeddings = config.max_position_embeddings + del config.max_position_embeddings + max_pos += 2 # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2 + assert max_pos >= current_max_pos + + # allocate a larger position embedding matrix for the encoder + new_encoder_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size) + # copy position embeddings over and over to initialize the new position embeddings + k = 2 + step = current_max_pos - 2 + while k < max_pos - 1: + new_encoder_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:] + k += step + model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed + + # allocate a larger position embedding matrix for the decoder + # new_decoder_pos_embed = model.model.decoder.embed_positions.weight.new_empty(max_pos, embed_size) + # # copy position embeddings over and over to initialize the new position embeddings + # k = 2 + # step = current_max_pos - 2 + # while k < max_pos - 1: + # new_decoder_pos_embed[k:(k + step)] = model.model.decoder.embed_positions.weight[2:] + # k += step + # model.model.decoder.embed_positions.weight.data = new_decoder_pos_embed + + # replace the `modeling_bart.SelfAttention` object with `LongformerSelfAttention` + config.attention_window = [attention_window] * config.num_hidden_layers + config.attention_dilation = [1] * config.num_hidden_layers + + for i, layer in enumerate(model.model.encoder.layers): + longformer_self_attn_for_bart = LongformerSelfAttentionForBart(config, 
layer_id=i) + + longformer_self_attn_for_bart.longformer_self_attn.query = layer.self_attn.q_proj + longformer_self_attn_for_bart.longformer_self_attn.key = layer.self_attn.k_proj + longformer_self_attn_for_bart.longformer_self_attn.value = layer.self_attn.v_proj + + longformer_self_attn_for_bart.longformer_self_attn.query_global = layer.self_attn.q_proj + longformer_self_attn_for_bart.longformer_self_attn.key_global = layer.self_attn.k_proj + longformer_self_attn_for_bart.longformer_self_attn.value_global = layer.self_attn.v_proj + + longformer_self_attn_for_bart.output = layer.self_attn.out_proj + + layer.self_attn = longformer_self_attn_for_bart + logger.info(f'saving model to {save_model_to}') + model.save_pretrained(save_model_to) + tokenizer.save_pretrained(save_model_to) + return model, tokenizer + + +def main(): + parser = argparse.ArgumentParser(description="Convert BART to LongBART. Replaces BART encoder's SelfAttention with LongformerSelfAttention") + parser.add_argument( + '--base_model', + type=str, + default='facebook/bart-large', + help='The name or path of the base model you want to convert' + ) + parser.add_argument( + '--tokenizer_name_or_path', + type=str, + default='facebook/bart-large', + help='The name or path of the tokenizer' + ) + parser.add_argument( + '--save_model_to', + type=str, + required=True, + help='The path to save the converted model' + ) + parser.add_argument( + '--attention_window', + type=int, + default=512, + help='attention window size for longformer self attention (one sided)' + ) + parser.add_argument( + '--max_pos', + type=int, + default=4096 * 4, + help='maximum encoder positions' + ) + + args = parser.parse_args() + + if not os.path.exists(args.save_model_to): + os.mkdir(args.save_model_to) + + create_long_model( + save_model_to=args.save_model_to, + base_model=args.base_model, + tokenizer_name_or_path=args.tokenizer_name_or_path, + attention_window=args.attention_window, + max_pos=args.max_pos + ) + + tokenizer = BartTokenizer.from_pretrained(args.save_model_to) + TXT = "My friends are <mask> but they eat too many carbs."
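# Sanity check: load the converted model back, enable gradient checkpointing, run the masked sentence through it, and print the top-5 predictions for the masked position.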
+ model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(args.save_model_to) + model.model.encoder.config.gradient_checkpointing = True + model.model.decoder.config.gradient_checkpointing = True + data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048) + input_ids = data['input_ids'] + attention_mask = data['attention_mask'] + decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id) + logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0] + masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + probs = logits[0, masked_index].softmax(dim=0) + values, predictions = probs.topk(5) + print(tokenizer.convert_ids_to_tokens(predictions)) + + +if __name__ == "__main__": + main() diff --git a/scripts/convert_t5_to_longformerencoderdecoder.py b/scripts/convert_t5_to_longformerencoderdecoder.py new file mode 100644 index 0000000..c06930b --- /dev/null +++ b/scripts/convert_t5_to_longformerencoderdecoder.py @@ -0,0 +1,157 @@ +import argparse +import logging +import os + +from transformers import T5Tokenizer + +from transformers import T5ForConditionalGeneration +from transformers.modeling_bart import shift_tokens_right +from longformer.longformer_t5_encoder_decoder import LongformerSelfAttentionForT5, LongformerEncoderDecoderConfigT5 +from longformer.longformer_t5_encoder_decoder import LongformerEncoderDecoderForConditionalGenerationT5 + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def create_long_model( + save_model_to, + base_model, + tokenizer_name_or_path, + attention_window, + max_pos +): + model = T5ForConditionalGeneration.from_pretrained(base_model) + tokenizer = T5Tokenizer.from_pretrained(tokenizer_name_or_path, model_max_length=max_pos) + config = LongformerEncoderDecoderConfigT5.from_pretrained(base_model) + model.config = config + + # in T5 attention_probs_dropout_prob is dropout_rate, but LongformerSelfAttention + # expects attention_probs_dropout_prob, so set it here + config.attention_probs_dropout_prob = config.dropout_rate + config.architectures = ['LongformerEncoderDecoderForConditionalGenerationT5', ] + + # extend position embeddings + tokenizer.model_max_length = max_pos + tokenizer.init_kwargs['model_max_length'] = max_pos + # current_max_pos, embed_size = model.model.embed_positions.weight.shape + # assert current_max_pos == config.max_position_embeddings + 2 + + # config.max_encoder_position_embeddings = max_pos + # config.max_decoder_position_embeddings = config.max_position_embeddings + # del config.max_position_embeddings + # # TODO: check what's the deal with T5 here. 
+ # max_pos += 2 # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2 + # assert max_pos >= current_max_pos + + # # allocate a larger position embedding matrix for the encoder + # new_encoder_pos_embed = model.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size) + # # copy position embeddings over and over to initialize the new position embeddings + # k = 2 + # step = current_max_pos - 2 + # while k < max_pos - 1: + # new_encoder_pos_embed[k:(k + step)] = model.model.encoder.embed_positions.weight[2:] + # k += step + # model.model.encoder.embed_positions.weight.data = new_encoder_pos_embed + + # allocate a larger position embedding matrix for the decoder + # new_decoder_pos_embed = model.model.decoder.embed_positions.weight.new_empty(max_pos, embed_size) + # # copy position embeddings over and over to initialize the new position embeddings + # k = 2 + # step = current_max_pos - 2 + # while k < max_pos - 1: + # new_decoder_pos_embed[k:(k + step)] = model.model.decoder.embed_positions.weight[2:] + # k += step + # model.model.decoder.embed_positions.weight.data = new_decoder_pos_embed + + # replace the `modeling_t5.T5Attention` object with `LongformerSelfAttention` + config.attention_window = [attention_window] * config.num_hidden_layers + config.attention_dilation = [1] * config.num_hidden_layers + # model.encoder.block = model.encoder.block[:1] + + for i, layer in enumerate(model.encoder.block): + self_attn = layer.layer[0].SelfAttention + + longformer_self_attn_for_t5 = LongformerSelfAttentionForT5(config, layer_id=i) + + longformer_self_attn_for_t5.longformer_self_attn.query = self_attn.q + longformer_self_attn_for_t5.longformer_self_attn.key = self_attn.k + longformer_self_attn_for_t5.longformer_self_attn.value = self_attn.v + + longformer_self_attn_for_t5.longformer_self_attn.query_global = self_attn.q + longformer_self_attn_for_t5.longformer_self_attn.key_global = self_attn.k + longformer_self_attn_for_t5.longformer_self_attn.value_global = self_attn.v + + longformer_self_attn_for_t5.output = self_attn.o + + layer.layer[0].SelfAttention = longformer_self_attn_for_t5 + + logger.info(f'saving model to {save_model_to}') + model.save_pretrained(save_model_to) + tokenizer.save_pretrained(save_model_to) + return model, tokenizer + + +def main(): + parser = argparse.ArgumentParser(description="Convert T5 to LongT5. 
Replaces T5 encoder's T5Attention with LongformerSelfAttention") + parser.add_argument( + '--base_model', + type=str, + default='t5-large', + help='The name or path of the base model you want to convert' + ) + parser.add_argument( + '--tokenizer_name_or_path', + type=str, + default='t5-large', + help='The name or path of the tokenizer' + ) + parser.add_argument( + '--save_model_to', + type=str, + required=True, + help='The path to save the converted model' + ) + parser.add_argument( + '--attention_window', + type=int, + default=512, + help='attention window size for longformer self attention (one sided)' + ) + parser.add_argument( + '--max_pos', + type=int, + default=4096 * 4, + help='maximum encoder positions' + ) + + args = parser.parse_args() + + if not os.path.exists(args.save_model_to): + os.mkdir(args.save_model_to) + + create_long_model( + save_model_to=args.save_model_to, + base_model=args.base_model, + tokenizer_name_or_path=args.tokenizer_name_or_path, + attention_window=args.attention_window, + max_pos=args.max_pos + ) + + tokenizer = T5Tokenizer.from_pretrained(args.save_model_to) + TXT = "My friends are but they eat too many carbs." + model = LongformerEncoderDecoderForConditionalGenerationT5.from_pretrained(args.save_model_to) + model.encoder.config.gradient_checkpointing = True + model.decoder.config.gradient_checkpointing = True + data = tokenizer([TXT], return_tensors='pt', padding='max_length', max_length=2048) + input_ids = data['input_ids'] + attention_mask = data['attention_mask'] + decoder_input_ids = shift_tokens_right(input_ids[:, :5], tokenizer.pad_token_id) + logits = model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False)[0] + masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + probs = logits[0, masked_index].softmax(dim=0) + values, predictions = probs.topk(5) + print(tokenizer.convert_ids_to_tokens(predictions)) + + +if __name__ == "__main__": + main() diff --git a/scripts/mem_profiler.py b/scripts/mem_profiler.py new file mode 100644 index 0000000..5d8e2f7 --- /dev/null +++ b/scripts/mem_profiler.py @@ -0,0 +1,69 @@ +from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration +from longformer.longformer_encoder_decoder import LongformerEncoderDecoderConfig + +from longformer.longformer import LongformerForMaskedLM +from longformer.longformer import LongformerConfig + +import torch +from torch.utils.data import DataLoader, Dataset +from pytorch_lightning import Trainer +import pytorch_lightning as pl + +seqlen = 1024 * 2 +global_size = seqlen // 100 +attention_window = 256 # one sided + + +class CoolDataset(Dataset): + def __len__(self): + return 1024 # number of examples + + def __getitem__(self, idx): + tokne_ids = torch.tensor([5] * seqlen) + mask = torch.tensor([1] * seqlen) + mask[:global_size] = 2 + return tokne_ids, mask + + +class MemoryProfiler(pl.LightningModule): + + def __init__(self, hparams=None): + super().__init__() + self.hparams = hparams + + config = LongformerEncoderDecoderConfig.from_pretrained('bart-long-4096') + # config = LongformerConfig.from_pretrained('roberta-large') + config.max_position_embeddings = seqlen + 2 + config.gradient_checkpointing = True + config.attention_mode = 'sliding_chunks' + # config.attention_mode = 'n2' + config.attention_window = [attention_window] * config.num_hidden_layers + config.attention_dilation = [1] * config.num_hidden_layers + self.model = 
LongformerEncoderDecoderForConditionalGeneration(config) + # self.model = LongformerForMaskedLM(config) + + def forward(self, x, y): + print(seqlen, global_size, attention_window, torch.cuda.max_memory_allocated(x.device) / 1024 ** 3) + # import ipdb; ipdb.set_trace() + # return self.model(x, attention_mask=y, decoder_input_ids=x[:, :attention_window * 2], use_cache=False) + return self.model(x, attention_mask=y) + + def training_step(self, batch, batch_idx): + # import ipdb; ipdb.set_trace() + x, y = batch + y_hat = self(x, y) + loss = y_hat[0].sum() + # import ipdb; ipdb.set_trace() + return {'loss': loss} + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + def train_dataloader(self): + return DataLoader(CoolDataset(), batch_size=2, num_workers=0) + + +if __name__ == '__main__': + model = MemoryProfiler(hparams={}) + trainer = Trainer(gpus=[0], progress_bar_refresh_rate=1, max_epochs=1, amp_level='O2', use_amp=True) + trainer.fit(model) diff --git a/scripts/pretrain.py b/scripts/pretrain.py new file mode 100644 index 0000000..8de5bbd --- /dev/null +++ b/scripts/pretrain.py @@ -0,0 +1,461 @@ +import argparse +import glob +import os +import random +import logging +import numpy as np +import math +from tqdm import tqdm +import time +import torch +from transformers import AutoTokenizer, AutoModelForMaskedLM +from transformers import DataCollatorForLanguageModeling +from transformers.optimization import AdamW, get_linear_schedule_with_warmup + +from torch.utils.data import Dataset, DataLoader +import pytorch_lightning as ptl +from pytorch_lightning.logging.test_tube import TestTubeLogger +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateLogger + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# DONE: reproduce RoBERTa numbers on the Longformer corpus +# DONE: testing ddp single machine +# DONE: testing ddp multiple machines +# DONE: testing resume from checkpoint +# TODO: try on a TPU-pod +# TODO: run on beaker on ai2-server1/2 + + +try: + import torch_xla.core.xla_model as xm +except ImportError: + XLA_AVAILABLE = False +else: + XLA_AVAILABLE = True + + +class MMapTextDataset(Dataset): + def __init__(self, mmap_filename, chunk_size, bos_token_id, eos_token_id): + # `chunk_size - 2` to reserve space for <s> and </s> + self.num_instances = np.memmap(mmap_filename, mode='r', dtype=np.uint16).shape[0] // (chunk_size - 2) + # defer loading the token_ids memmap until after the first __getitem__ call. + # when spawning new processes for ddp, there is a hard limit in python < 3.8 that + # pickle files need to be < 4GB.
By waiting until after the first __getitem__ we + # don't have to pickle the memmap + self.token_ids = None + self._mmap_filename = mmap_filename + self._chunk_size = chunk_size + self._bos_token_id = bos_token_id + self._eos_token_id = eos_token_id + + def __len__(self): + return self.num_instances + + def __getitem__(self, i): + if self.token_ids is None: + self.token_ids = np.memmap(self._mmap_filename, mode='r', dtype=np.uint16) + from_index = i * (self._chunk_size - 2) + to_index = (i + 1) * (self._chunk_size - 2) + data = np.concatenate(([self._bos_token_id], self.token_ids[from_index:to_index], [self._eos_token_id])) + return torch.tensor(data, dtype=torch.long) + + # ========================= preprocessing code ========================= # + @staticmethod + def _process_file(full_fname): + "Step 1: tokenize an input text file then save token ids into `np.memmap` shards of size `args.shard_size`" + fname = full_fname.split('/')[-1] + log_filename = f'{args.input_dir}/logs-{args.shard_size}/{fname}.log' + if os.path.isfile(log_filename): + logging.info(f'Skipping {full_fname} ...') + return # log file already exists. Skip current file. + + logging.info(f'Processing {full_fname} ...') + with open(full_fname, 'r') as fin: + token_list = [] + shard_count = 0 + tokens_count = 0 + + def _write_shard(): + if len(token_list) == 0: + return + if token_list[-1] != MMapTextDataset.tokenizer.sep_token_id: # handle a rare case + token_list.append(MMapTextDataset.tokenizer.sep_token_id) + shared_filename = f'{args.input_dir}/shards-{args.shard_size}/{fname}-{shard_count}.bin' + logging.info(f'Writing {len(token_list)} tokens to shared {shared_filename}') + fp = np.memmap(shared_filename, dtype=np.uint16, mode='w+', shape=len(token_list)) + fp[:] = token_list[:] + del fp # flush and close file + for line in tqdm(fin): + line = line.strip() + if line == '': # drop empty lines + continue + tokens = MMapTextDataset.tokenizer.encode(line, add_special_tokens=False) # `__getitem__` adds special tokens + token_list.extend(tokens) + if len(token_list) > args.shard_size: + _write_shard() + tokens_count += len(token_list) + token_list = [] + shard_count += 1 + else: + token_list.append(MMapTextDataset.tokenizer.sep_token_id) + _write_shard() + tokens_count += len(token_list) + with open(log_filename, 'w') as f: + f.write(f'Generated {tokens_count} tokens in {shard_count + 1} shards') + + @staticmethod + def _combine_shards(output_fname, shards_list): + "Step 2: combining memmap shards into one `train.bin` or `val.bin` file" + total_size = 0 + for filename in shards_list: + total_size += np.memmap(filename, mode='r', dtype=np.uint16).shape[0] + logging.info(f'Writing {total_size} tokens to {output_fname}') + all_token_ids = np.empty(total_size, dtype=np.uint16) + last_token_index = 0 + for filename in tqdm(shards_list): + shared = np.memmap(filename, mode='r', dtype=np.uint16) + all_token_ids[last_token_index:last_token_index+len(shared)] = shared[:] + last_token_index += len(shared) + fp = np.memmap(output_fname, dtype=np.uint16, mode='w+', shape=total_size) + fp[:] = all_token_ids[:] + del fp + + @staticmethod + def raw_text_to_mmap(args): + """This is the main preprocessing function. It processes all the text files in `args.input_dir` and + outputs two np.memmap files, one for training and one for validation with ratio `args.train_dev_split`. + Processing each input file involves tokenizing it, sharding it into shards of size `args.shard_size`, + then writing each shard as an np.memmap file. 
The stream of tokens in the memmap file represents documents + separated with `tokenizer.sep_token`. In `__getitem__`, the `tokenizer.bos_token` and `tokenizer.eos_token` + are added. The reason for not adding them at preprocessing time is to allow different sequence lengths + later on. Notice that this is the "FULL-SENTENCES" setting in the RoBERTa paper, Table2. + """ + MMapTextDataset.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=True) + assert len(MMapTextDataset.tokenizer) < 65535 # will use uint16 to store token ids + all_files = glob.glob(f'{args.input_dir}/*.txt') + + if os.path.exists(f'{args.input_dir}/cache/train.bin') and os.path.exists(f'{args.input_dir}/cache/val.bin'): + logger.info("Cache already exists. Remove the cache directory to regenerate") + return + try: + os.mkdir(f'{args.input_dir}/cache/') + except FileExistsError: + pass + try: + os.mkdir(f'{args.input_dir}/shards-{args.shard_size}/') + except FileExistsError: + pass + try: + os.mkdir(f'{args.input_dir}/logs-{args.shard_size}/') # log progrss to be able to resume + except FileExistsError: + pass + + # STEP1: tokenizing and saving to shards + if args.num_preprocessing_workers > 1: + from multiprocessing.pool import Pool + with Pool(args.num_preprocessing_workers) as p: + list(tqdm(p.imap(MMapTextDataset._process_file, all_files), total=len(all_files))) + else: + [MMapTextDataset._process_file(f) for f in tqdm(all_files)] + + # STEP2: shuffling shards and combining them into train.bin and val.bin files + all_shards = glob.glob(f'{args.input_dir}/shards-{args.shard_size}/*.bin') + random.shuffle(all_shards) # shuffling based on shards not individual lines + val_shards_count = int(args.train_dev_split * len(all_shards)) + val_shards = all_shards[:val_shards_count] + train_shards = all_shards[val_shards_count:] + # TODO: if MMapTextDataset._combining_shards is very slow for large files, it can be skipped but we nned to + # update the dataset to read from multiple shards directly + MMapTextDataset._combine_shards(f'{args.input_dir}/cache/val.bin', val_shards) + MMapTextDataset._combine_shards(f'{args.input_dir}/cache/train.bin', train_shards) + + del MMapTextDataset.tokenizer + # ========================= end preprocessing code ========================= # + + +class Pretrainer(ptl.LightningModule): + + def __init__(self, hparams): + super().__init__() + + self.args = hparams + self.hparams = self.args + + self.model = AutoModelForMaskedLM.from_pretrained(args.model) + self.config = self.model.config + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + self.bos_token_id = tokenizer.bos_token_id + + logger.info(f'Creating dataset cache from dir {self.args.input_dir}. This could be slow the first time.') + MMapTextDataset.raw_text_to_mmap(args) + + # TODO: add support for other objective functions (whole word masking, BART objectives) + self.data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, mlm=True, mlm_probability=self.args.mlm_prob + ) + self.start_time = 0 + + def to(self, *args, **kwargs): + param_count_before_to = len(list(self.parameters())) + super().to(*args, **kwargs) + if self.trainer.use_tpu: + # need to re-tie the weights after moving to XLA! 
+ self.model.tie_weights() + if 'roberta' in self.args.model: + self.model.lm_head.bias = self.model.lm_head.decoder.bias + param_count_after_to = len(list(self.parameters())) + assert param_count_before_to == param_count_after_to + + def forward(self, input_ids=None, labels=None): + # get the padding mask - 1 for NOT masked, 0 for MASKED/PAD + attention_mask = (input_ids != self.pad_token_id).int() + + # output is loss, prediction_scores, hidden_states + output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + return output[0] # loss + + def training_step(self, batch, batch_nb): + loss = self(**batch) + input_ids = batch['input_ids'] + tensorboard_logs = { + 'input_size': input_ids.numel(), + 'mlm_loss': loss, + 'mlm_bpc': loss/math.log(2), + 'mlm_perplexity': torch.exp(loss), + 'token_per_step': input_ids.numel() * self.args.grad_accum * self.trainer.world_size, + } + if self.start_time != 0: + elapsed_time = time.time() - self.start_time + tensorboard_logs['second_per_batch'] = elapsed_time + self.start_time = time.time() + if self.on_gpu: + tensorboard_logs['memory'] = torch.cuda.memory_allocated(loss.device) / 1024 ** 3 + + return {'loss': loss, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_nb): + # TODO: log how long evaluation takes + self.start_time = 0 # reset training_step timer + loss = self(**batch) + tensorboard_logs = { + 'val_mlm_loss': loss.detach(), + } + return {'val_loss': tensorboard_logs["val_mlm_loss"], 'log': tensorboard_logs} + + def validation_epoch_end(self, outputs): + avg_loss = torch.stack([x['log']['val_mlm_loss'] for x in outputs if 'val_mlm_loss' in x['log']]).mean() + if self.use_ddp: + # TODO: PTL is already doing this. Is it still needed here? + # https://github.com/PyTorchLightning/pytorch-lightning/blob/0.8.5/pytorch_lightning/metrics/converters.py#L251 + torch.distributed.all_reduce(avg_loss, op=torch.distributed.ReduceOp.SUM) + avg_loss /= torch.distributed.get_world_size() + elif self.use_tpu: + avg_loss = xm.all_reduce(xm.REDUCE_SUM, avg_loss) / xm.xrt_world_size() + + logs = {'val_mlm_loss': avg_loss} + return {'log': logs, 'progress_bar': logs, "val_loss": avg_loss} + + def configure_optimizers(self): + no_decay = ["bias", "LayerNorm.weight"] + + optimizer_grouped_parameters = [ + { + "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], + "weight_decay": self.args.weight_decay, + }, + { + "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], + "weight_decay": 0.0, + }, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.lr, eps=self.args.adam_epsilon) + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.args.train_steps + ) + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] + + def _get_loader(self, fname, is_train): + dataset = MMapTextDataset(fname, chunk_size=self.args.seqlen, + bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id) + + # TODO: consider `replace_sampler_ddp=True` and removing the following if statement + if self.trainer.use_ddp: + sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train) + shuffle = False + elif self.trainer.use_tpu: + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=xm.xrt_world_size(), + rank=xm.get_ordinal(), + shuffle=is_train, + ) + shuffle = False + else: + sampler = None 
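+            # single-process case: no distributed sampler, so let the DataLoader itself shuffle (training split only)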
+ shuffle = is_train + + loader = DataLoader( + dataset, + batch_size=self.args.batch_size, + shuffle=shuffle, + sampler=sampler, + num_workers=self.args.num_workers, + collate_fn=self.data_collator, + drop_last=is_train, + ) + return loader + + def train_dataloader(self): + return self._get_loader(f'{self.args.input_dir}/cache/train.bin', True) + + def val_dataloader(self): + return self._get_loader(f'{self.args.input_dir}/cache/val.bin', False) + + def grad_norm(self, norm_type): + # Override PTL `grad_norm` function to only return `total_grad_norm` instead norms of individual params + # TODO: grad_norm reporting needs to take fp16 loss scale into account + parameters = [p for p in self.parameters() if p.grad is not None] + device = parameters[0].device + total_norm = torch.zeros([], device=device if parameters else None) + norm_type = float(norm_type) + for p in parameters: + param_norm = p.grad.data.pow(norm_type).sum() + total_norm.add_(param_norm) + total_norm = (total_norm ** (1.0 / norm_type)) + return {'total_grad_norm': total_norm} + + @staticmethod + def add_args(parser): + parser.add_argument("--seed", type=int, default=3) + + # Dataset. Some of these params are only useful when generating the dataset cache + parser.add_argument("--input_dir", type=str, default='/net/nfs.corp/s2-research/beltagy/longformer/data/') + # Used only at the preprocessing phase + parser.add_argument("--train_dev_split", type=float, default=0.05) + parser.add_argument("--shard_size", type=int, default=1024 ** 3 // 4) # 250MB + parser.add_argument("--num_preprocessing_workers", type=int, default=1) + # Used only at the training phase + parser.add_argument("--seqlen", type=int, default=512) + parser.add_argument("--mlm_prob", type=float, default=0.15) + + # HF model loading + parser.add_argument("--tokenizer", type=str, default='roberta-base') + parser.add_argument("--model", type=str, default='roberta-base') + + # Checkpointing and logging + parser.add_argument("--save_dir", type=str, default='/runs/') + parser.add_argument("--save_prefix", type=str, default='test', + help="path of output directory is --save_dir/--save_prefix") + parser.add_argument("--resume", type=str, default=None, # It is better to use a different output dir. + help="Path to a checkpoint to load model weights and training state. It overwrites args") + parser.add_argument("--resume_model_only", type=str, default=None, + help="Path to a checkpoint to load model weights but not training state") + parser.add_argument("--log_rate", type=int, default=10) + parser.add_argument("--disable_checkpointing", type=bool, default=False) + + # Training hyperparams + parser.add_argument("--lr", type=float, default=1e-5) + parser.add_argument("--train_steps", type=int, default=3000, help='# training grad. updates') + parser.add_argument("--warmup_steps", type=int, default=1000, help='# warmup grad. updates') + parser.add_argument("--val_every", type=int, default=1000, help='# training grad. updates between evaluations') + parser.add_argument("--val_batches", type=int, default=1000, help='# evaluation **batches**') + parser.add_argument("--weight_decay", type=float, default=0.01) + parser.add_argument("--adam_epsilon", type=float, default=1e-6) + parser.add_argument("--grad_clip", type=float, default=0) # TODO: test this with fp16. 
Likely not working + + # RoBERTa's tokens_per_step = 2^18 = 512(seqlen) x 1(gpu_count) x 32(batch_size) x 16(grad_accum) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--grad_accum", type=int, default=1) + + # Compute resources + parser.add_argument("--fp16", type=bool, default=False) + parser.add_argument("--num_workers", type=int, default=0) + parser.add_argument("--gpu_count", type=int, default=1, # `--gpus` is reserved for internal use by PTL + help="Number of gpus. This respects `CUDA_VISIBLE_DEVICES`") + + # For multi-node training, use the PyTorch launch script. The script and instructions can be found here: + # https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py. + # To run PTL in a mode compatible with the launch script, two things are needed: + # - pass the argument `--use_env` to `torch.distributed.launch` + # - make sure `--nproc_per_node` matches `--gpu_count` and `--nnodes` matches `--node_count`. + # For example, to run on 2 nodes, 3 gpus each, the command line on node rank 1 would be like: + # >>>> python -m torch.distributed.launch \ + # --use_env --nnodes 2 --nproc_per_node 3 \ + # --node_rank 1 --master_addr s2-server4 --master_port 12343 \ + # scripts/pretrain.py \ + # --gpu_count 2 --node_count 2 \ + # --input_dir my_data_dir --save_prefix test_multinode + parser.add_argument("--node_count", type=int, default=1, + help="Number of nodes. It needs to match --nnodes of torch.distributed.launch") + parser.add_argument("--tpu_core_count", type=int, default=None) + + return parser + + +def main(args): + random.seed(args.seed * 10) + np.random.seed(args.seed * 100) + torch.manual_seed(args.seed * 1000) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed * 10000) + + if args.resume_model_only is not None: + pretrainer = Pretrainer.load_from_checkpoint(args.resume_model_only, args) + else: + pretrainer = Pretrainer(args) + + # logger here is a SummaryWritter for tensorboard + # it is used by the trainer, and certain return variables + # from the model are automatically logged + logger = TestTubeLogger( + save_dir=args.save_dir, + name=args.save_prefix, + version=0 # always use version=0 + ) + + checkpoint_callback = ModelCheckpoint( + # model saved to filepath/prefix_.... 
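+        # keeps only the best checkpoint by `val_loss` plus the most recent one (`save_last=True`)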
+ filepath=os.path.join(args.save_dir, args.save_prefix, 'checkpoint'), + prefix='', + save_top_k=1, + save_last=True, + verbose=True, + monitor='val_loss', + mode='min', + period=-1, # to allow multiple checkpoints per epoch + ) + + args.val_every *= args.grad_accum # PTL is expecting number of batches_per_gpu + trainer = ptl.Trainer( + gpus=args.gpu_count, + num_nodes=args.node_count, + num_tpu_cores=args.tpu_core_count, + distributed_backend='ddp' if (args.gpu_count > 1 or args.node_count > 1) else None, + replace_sampler_ddp=False, + track_grad_norm=2, + max_epochs=10000, min_epochs=0, max_steps=args.train_steps, # run for many epochs, but stop after max_steps + val_check_interval=args.val_every, limit_val_batches=args.val_batches, + early_stop_callback=None, + row_log_interval=args.log_rate, + progress_bar_refresh_rate=args.log_rate, + logger=logger, + checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else None, + accumulate_grad_batches=args.grad_accum, + resume_from_checkpoint=args.resume, + gradient_clip_val=args.grad_clip, + precision=16 if args.fp16 else 32, amp_level='O2', + num_sanity_val_steps=2, + callbacks=[LearningRateLogger()], + ) + trainer.fit(pretrainer) + + +if __name__ == "__main__": + parser = Pretrainer.add_args(argparse.ArgumentParser(description="pretrain")) + args = parser.parse_args() + main(args) diff --git a/scripts/summarization.py b/scripts/summarization.py new file mode 100644 index 0000000..8374022 --- /dev/null +++ b/scripts/summarization.py @@ -0,0 +1,348 @@ +import os +import argparse +import random +import numpy as np + +import torch +from torch.utils.data import DataLoader, Dataset +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig +from transformers.optimization import get_linear_schedule_with_warmup, Adafactor +import nlp +from rouge_score import rouge_scorer + +import pytorch_lightning as pl +from pytorch_lightning.logging import TestTubeLogger +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel + + +from longformer import LongformerEncoderDecoderForConditionalGeneration, LongformerEncoderDecoderConfig +from longformer.sliding_chunks import pad_to_window_size + + +def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): + """From fairseq""" + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + if ignore_index is not None: + pad_mask = target.eq(ignore_index) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + count = (~pad_mask).sum() + else: + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + count = nll_loss.numel() + + nll_loss = nll_loss.sum() / count + smooth_loss = smooth_loss.sum() / count + eps_i = epsilon / lprobs.size(-1) + loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss + return loss, nll_loss + + +class SummarizationDataset(Dataset): + def __init__(self, hf_dataset, tokenizer, max_input_len, max_output_len): + self.hf_dataset = hf_dataset + self.tokenizer = tokenizer + self.max_input_len = max_input_len + self.max_output_len = max_output_len + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + entry = self.hf_dataset[idx] + input_ids = self.tokenizer.encode(entry['article'], truncation=True, max_length=self.max_input_len) + output_ids = 
self.tokenizer.encode(entry['abstract'], truncation=True, max_length=self.max_output_len) + if self.tokenizer.bos_token_id is None: # pegasus + output_ids = [self.tokenizer.pad_token_id] + output_ids + return torch.tensor(input_ids), torch.tensor(output_ids) + + @staticmethod + def collate_fn(batch): + # A hack to know if this is bart or pegasus. DDP doesn't like global variables nor class-level memebr variables + if batch[0][0][-1].item() == 2: + pad_token_id = 1 # AutoTokenizer.from_pretrained('facebook/bart-base').pad_token_id + elif batch[0][0][-1].item() == 1: + pad_token_id = 0 # AutoTokenizer.from_pretrained('google/pegasus-large').pad_token_id + else: + assert False + + input_ids, output_ids = list(zip(*batch)) + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id) + output_ids = torch.nn.utils.rnn.pad_sequence(output_ids, batch_first=True, padding_value=pad_token_id) + return input_ids, output_ids + + +class Summarizer(pl.LightningModule): + + def __init__(self, args): + super().__init__() + self.args = args + self.hparams = args + self.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer, use_fast=True) + + if 'long' in self.args.model_path: + config = LongformerEncoderDecoderConfig.from_pretrained(self.args.model_path) + config.attention_dropout = self.args.attention_dropout + config.gradient_checkpointing = self.args.grad_ckpt + config.attention_mode = self.args.attention_mode + config.attention_window = [self.args.attention_window] * config.encoder_layers + self.model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained( + self.args.model_path, config=config) + else: + config = AutoConfig.from_pretrained(self.args.model_path) + config.attention_dropout = self.args.attention_dropout + self.model = AutoModelForSeq2SeqLM.from_pretrained( + self.args.model_path, config=config) + self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None + + def _prepare_input(self, input_ids): + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) + attention_mask[input_ids == self.tokenizer.pad_token_id] = 0 + if isinstance(self.model, LongformerEncoderDecoderForConditionalGeneration): + attention_mask[:, 0] = 2 # global attention on one token for all model params to be used, which is important for gradient checkpointing to work + if self.args.attention_mode == 'sliding_chunks': + half_padding_mod = self.model.config.attention_window[0] + elif self.args.attention_mode == 'sliding_chunks_no_overlap': + half_padding_mod = self.model.config.attention_window[0] / 2 + else: + raise NotImplementedError + input_ids, attention_mask = pad_to_window_size( # ideally, should be moved inside the LongformerModel + input_ids, attention_mask, half_padding_mod, self.tokenizer.pad_token_id) + return input_ids, attention_mask + + def forward(self, input_ids, output_ids): + input_ids, attention_mask = self._prepare_input(input_ids) + decoder_input_ids = output_ids[:, :-1] + decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id) + labels = output_ids[:, 1:].clone() + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + use_cache=False,) + lm_logits = outputs[0] + if self.args.label_smoothing == 0: + # Same behavior as modeling_bart.py, besides ignoring pad_token_id + ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id) + 
assert lm_logits.shape[-1] == self.model.config.vocab_size + loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1)) + else: + lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, labels, self.args.label_smoothing, ignore_index=self.tokenizer.pad_token_id + ) + return [loss] + + def training_step(self, batch, batch_nb): + output = self.forward(*batch) + loss = output[0] + lr = loss.new_zeros(1) + self.trainer.optimizers[0].param_groups[0]['lr'] + tensorboard_logs = {'train_loss': loss, 'lr': lr, + 'input_size': batch[0].numel(), + 'output_size': batch[1].numel(), + 'mem': torch.cuda.memory_allocated(loss.device) / 1024 ** 3 if torch.cuda.is_available() else 0} + return {'loss': loss, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_nb): + for p in self.model.parameters(): + p.requires_grad = False + + outputs = self.forward(*batch) + vloss = outputs[0] + input_ids, output_ids = batch + input_ids, attention_mask = self._prepare_input(input_ids) + generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, + use_cache=True, max_length=self.args.max_output_len, + num_beams=1) + generated_str = self.tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True) + gold_str = self.tokenizer.batch_decode(output_ids.tolist(), skip_special_tokens=True) + scorer = rouge_scorer.RougeScorer(rouge_types=['rouge1', 'rouge2', 'rougeL', 'rougeLsum'], use_stemmer=False) + rouge1 = rouge2 = rougel = rougelsum = 0.0 + for ref, pred in zip(gold_str, generated_str): + score = scorer.score(ref, pred) + rouge1 += score['rouge1'].fmeasure + rouge2 += score['rouge2'].fmeasure + rougel += score['rougeL'].fmeasure + rougelsum += score['rougeLsum'].fmeasure + rouge1 /= len(generated_str) + rouge2 /= len(generated_str) + rougel /= len(generated_str) + rougelsum /= len(generated_str) + + return {'vloss': vloss, + 'rouge1': vloss.new_zeros(1) + rouge1, + 'rouge2': vloss.new_zeros(1) + rouge2, + 'rougeL': vloss.new_zeros(1) + rougel, + 'rougeLsum': vloss.new_zeros(1) + rougelsum, } + + def validation_epoch_end(self, outputs): + for p in self.model.parameters(): + p.requires_grad = True + + names = ['vloss', 'rouge1', 'rouge2', 'rougeL', 'rougeLsum'] + metrics = [] + for name in names: + metric = torch.stack([x[name] for x in outputs]).mean() + if self.trainer.use_ddp: + torch.distributed.all_reduce(metric, op=torch.distributed.ReduceOp.SUM) + metric /= self.trainer.world_size + metrics.append(metric) + logs = dict(zip(*[names, metrics])) + print(logs) + return {'avg_val_loss': logs['vloss'], 'log': logs, 'progress_bar': logs} + + def test_step(self, batch, batch_nb): + return self.validation_step(batch, batch_nb) + + def test_epoch_end(self, outputs): + result = self.validation_epoch_end(outputs) + print(result) + + def configure_optimizers(self): + if self.args.adafactor: + optimizer = Adafactor(self.model.parameters(), lr=self.args.lr, scale_parameter=False, relative_step=False) + else: + optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr) + if self.args.debug: + return optimizer # const LR + num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1 + num_steps = self.args.dataset_size * self.args.epochs / num_gpus / self.args.grad_accum / self.args.batch_size + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=self.args.warmup, num_training_steps=num_steps + ) + return [optimizer], [{"scheduler": scheduler, "interval": 
"step"}] + + def _get_dataloader(self, current_dataloader, split_name, is_train): + if current_dataloader is not None: + return current_dataloader + dataset = SummarizationDataset(hf_dataset=self.hf_datasets[split_name], tokenizer=self.tokenizer, + max_input_len=self.args.max_input_len, max_output_len=self.args.max_output_len) + sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train) if self.trainer.use_ddp else None + return DataLoader(dataset, batch_size=self.args.batch_size, shuffle=(sampler is None), + num_workers=self.args.num_workers, sampler=sampler, + collate_fn=SummarizationDataset.collate_fn) + + @pl.data_loader + def train_dataloader(self): + self.train_dataloader_object = self._get_dataloader(self.train_dataloader_object, 'train', is_train=True) + return self.train_dataloader_object + + @pl.data_loader + def val_dataloader(self): + self.val_dataloader_object = self._get_dataloader(self.val_dataloader_object, 'validation', is_train=False) + return self.val_dataloader_object + + @pl.data_loader + def test_dataloader(self): + self.test_dataloader_object = self._get_dataloader(self.test_dataloader_object, 'test', is_train=False) + return self.test_dataloader_object + + def configure_ddp(self, model, device_ids): + model = LightningDistributedDataParallel( + model, + device_ids=device_ids, + find_unused_parameters=False + ) + return model + + @staticmethod + def add_model_specific_args(parser, root_dir): + parser.add_argument("--save_dir", type=str, default='summarization') + parser.add_argument("--save_prefix", type=str, default='test') + parser.add_argument("--batch_size", type=int, default=16, help="Batch size") + parser.add_argument("--grad_accum", type=int, default=1, help="number of gradient accumulation steps") + parser.add_argument("--gpus", type=int, default=-1, + help="Number of gpus. 0 for CPU") + parser.add_argument("--warmup", type=int, default=1000, help="Number of warmup steps") + parser.add_argument("--lr", type=float, default=0.00003, help="Maximum learning rate") + parser.add_argument("--val_every", type=float, default=1.0, help="Number of training steps between validations") + parser.add_argument("--val_percent_check", default=1.00, type=float, help='Percent of validation data used') + parser.add_argument("--num_workers", type=int, default=0, help="Number of data loader workers") + parser.add_argument("--seed", type=int, default=1234, help="Seed") + parser.add_argument("--epochs", type=int, default=5, help="Number of epochs") + parser.add_argument("--disable_checkpointing", action='store_true', help="No logging or checkpointing") + parser.add_argument("--max_output_len", type=int, default=256, + help="maximum num of wordpieces/summary. Used for training and testing") + parser.add_argument("--max_input_len", type=int, default=512, + help="maximum num of wordpieces/summary. Used for training and testing") + parser.add_argument("--test", action='store_true', help="Test only, no training") + parser.add_argument("--model_path", type=str, default='facebook/bart-base', + help="Path to the checkpoint directory or model name") + parser.add_argument("--tokenizer", type=str, default='facebook/bart-base') + parser.add_argument("--no_progress_bar", action='store_true', help="no progress bar. Good for printing") + parser.add_argument("--fp32", action='store_true', help="default is fp16. 
Use --fp32 to switch to fp32") + parser.add_argument("--debug", action='store_true', help="debug run") + parser.add_argument("--resume_ckpt", type=str, help="Path of a checkpoint to resume from") + parser.add_argument('--grad_ckpt', action='store_true', help='Enable gradient checkpointing to save memory') + parser.add_argument("--attention_dropout", type=float, default=0.1, help="attention dropout") + parser.add_argument("--attention_mode", type=str, default='sliding_chunks', help="Longformer attention mode") + parser.add_argument("--attention_window", type=int, default=512, help="Attention window") + parser.add_argument("--label_smoothing", type=float, default=0.0, required=False) + parser.add_argument("--adafactor", action='store_true', help="Use adafactor optimizer") + + return parser + + +def main(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + model = Summarizer(args) + model.hf_datasets = nlp.load_dataset('scientific_papers', 'arxiv') + + logger = TestTubeLogger( + save_dir=args.save_dir, + name=args.save_prefix, + version=0 # always use version=0 + ) + + checkpoint_callback = ModelCheckpoint( + filepath=os.path.join(args.save_dir, args.save_prefix, "checkpoints"), + save_top_k=5, + verbose=True, + monitor='avg_val_loss', + mode='min', + period=-1, + prefix='' + ) + + print(args) + + args.dataset_size = 203037 # hardcode dataset size. Needed to compute number of steps for the lr scheduler + + trainer = pl.Trainer(gpus=args.gpus, distributed_backend='ddp' if torch.cuda.is_available() else None, + track_grad_norm=-1, + max_epochs=args.epochs if not args.debug else 100, + max_steps=None if not args.debug else 1, + replace_sampler_ddp=False, + accumulate_grad_batches=args.grad_accum, + val_check_interval=args.val_every if not args.debug else 1, + num_sanity_val_steps=2 if not args.debug else 0, + check_val_every_n_epoch=1 if not args.debug else 1, + val_percent_check=args.val_percent_check, + test_percent_check=args.val_percent_check, + logger=logger, + checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else False, + show_progress_bar=not args.no_progress_bar, + use_amp=not args.fp32, amp_level='O2', + resume_from_checkpoint=args.resume_ckpt, + ) + if not args.test: + trainer.fit(model) + trainer.test(model) + + +if __name__ == "__main__": + main_arg_parser = argparse.ArgumentParser(description="summarization") + parser = Summarizer.add_model_specific_args(main_arg_parser, os.getcwd()) + args = parser.parse_args() + main(args) diff --git a/scripts/test_tpu.py b/scripts/test_tpu.py new file mode 100644 index 0000000..e692890 --- /dev/null +++ b/scripts/test_tpu.py @@ -0,0 +1,44 @@ +import torch +from torch.utils.data import DataLoader, Dataset +from transformers import AutoModel +import pytorch_lightning as pl + + +class CoolDataset(Dataset): + + def __len__(self): + return 128 * 128 + + def __getitem__(self, idx): + return torch.tensor([1, 2, 3, 4] * 128 * 8), torch.tensor([1, 1, 1, 1] * 128 * 8) + + +class CoolSystem(pl.LightningModule): + + def __init__(self): + super().__init__() + + self.model = AutoModel.from_pretrained('allenai/longformer-base-4096') + # self.model = AutoModel.from_pretrained('roberta-base') + + def forward(self, x, y): + return self.model(x, attention_mask=None) + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x, y) + loss = y_hat[0].sum() + return {'loss': loss} + + def 
configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + def train_dataloader(self): + loader = DataLoader(CoolDataset(), batch_size=1, num_workers=0) + return loader + + +if __name__ == '__main__': + model = CoolSystem() + trainer = pl.Trainer(num_tpu_cores=8, progress_bar_refresh_rate=1, max_epochs=10, num_sanity_val_steps=0, gpus=0) + trainer.fit(model) diff --git a/scripts/triviaqa.py b/scripts/triviaqa.py index 281c297..967f97a 100644 --- a/scripts/triviaqa.py +++ b/scripts/triviaqa.py @@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader, Dataset -from transformers import RobertaTokenizer +from transformers import RobertaTokenizer, AutoModel, AutoConfig, AutoModelWithLMHead from scripts.triviaqa_utils import evaluation_utils import pytorch_lightning as pl @@ -110,11 +110,13 @@ def is_whitespace(c): try: start_position = char_to_word_offset[answer_offset] end_position = char_to_word_offset[answer_offset + answer_length - 1] - except: + token_ids = self.tokenizer.encode(orig_answer_text) + except RuntimeError: print(f'Reading example {idx} failed') start_position = 0 end_position = 0 - answer_spans.append({'start': start_position, 'end': end_position}) + answer_spans.append({'start': start_position, 'end': end_position, + 'text': orig_answer_text, 'token_ids': token_ids}) # ===== Given an example, convert it into tensors ============= query_tokens = self.tokenizer.tokenize(question_text) @@ -146,6 +148,7 @@ def is_whitespace(c): segment_ids_list = [] start_positions_list = [] end_positions_list = [] + answer_token_ids_list = [] for slice_start in range(0, len(all_doc_tokens), max_tokens_per_doc_slice - self.doc_stride): slice_end = min(slice_start + max_tokens_per_doc_slice, len(all_doc_tokens)) @@ -172,6 +175,7 @@ def is_whitespace(c): doc_offset = len(query_tokens) + 2 - slice_start start_positions = [] end_positions = [] + answer_token_ids = [] for answer_span in answer_spans: start_position = answer_span['start'] end_position = answer_span['end'] @@ -183,6 +187,7 @@ def is_whitespace(c): continue start_positions.append(tok_start_position_in_doc + doc_offset) end_positions.append(tok_end_position_in_doc + doc_offset) + answer_token_ids.append(answer_span['token_ids']) assert len(start_positions) == len(end_positions) if self.ignore_seq_with_no_answers and len(start_positions) == 0: continue @@ -190,32 +195,58 @@ def is_whitespace(c): # answers from start_positions and end_positions if > self.max_num_answers start_positions = start_positions[:self.max_num_answers] end_positions = end_positions[:self.max_num_answers] + answer_token_ids = answer_token_ids[:self.max_num_answers] # -1 padding up to self.max_num_answers padding_len = self.max_num_answers - len(start_positions) start_positions.extend([-1] * padding_len) end_positions.extend([-1] * padding_len) + answer_token_ids.extend([[]] * padding_len) # replace duplicate start/end positions with `-1` because duplicates can result into -ve loss values found_start_positions = set() found_end_positions = set() - for i, (start_position, end_position) in enumerate(zip(start_positions, end_positions)): + found_answer_token_ids = set() + for i, (start_position, end_position, answer_tokens) in enumerate( + zip(start_positions, end_positions, answer_token_ids) + ): if start_position in found_start_positions: start_positions[i] = -1 if end_position in found_end_positions: end_positions[i] = -1 + answer_tokens_as_str = ','.join([str(x) for x in answer_tokens]) + if 
answer_tokens_as_str in found_answer_token_ids: + answer_token_ids[i] = [] found_start_positions.add(start_position) found_end_positions.add(end_position) + found_answer_token_ids.add(answer_tokens_as_str) input_ids_list.append(input_ids) input_mask_list.append(input_mask) segment_ids_list.append(segment_ids) start_positions_list.append(start_positions) end_positions_list.append(end_positions) + answer_token_ids_list.append(answer_token_ids) + + # pad answers in answer_token_ids_list to the longest answer + max_answer_len = max([len(item) for sublist in answer_token_ids_list for item in sublist]) # flat list + if max_answer_len == 0: + max_answer_len = 2 + for answers_of_one_slice in answer_token_ids_list: + for answer_tokens in answers_of_one_slice: + if len(answer_tokens) == 0: + # TODO: or ? + padding_len = max_answer_len - len(answer_tokens) - 2 + answer_tokens.extend([self.tokenizer.bos_token_id, self.tokenizer.eos_token_id] + + ([self.tokenizer.pad_token_id] * padding_len)) + else: + padding_len = max_answer_len - len(answer_tokens) + answer_tokens.extend([self.tokenizer.pad_token_id] * padding_len) tensors_list.append((torch.tensor(input_ids_list), torch.tensor(input_mask_list), torch.tensor(segment_ids_list), torch.tensor(start_positions_list), torch.tensor(end_positions_list), + torch.tensor(answer_token_ids_list), self._get_qid(qa['id']), qa["aliases"])) # for eval return tensors_list @@ -259,14 +290,39 @@ def __init__(self, args): self.tokenizer.model_max_length = self.args.max_seq_len self.model = self.load_model() self.num_labels = 2 - self.qa_outputs = torch.nn.Linear(self.model.config.hidden_size, self.num_labels) + if not self.args.seq2seq: + self.qa_outputs = torch.nn.Linear(self.model.config.hidden_size, self.num_labels) self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None def load_model(self): - model = Longformer.from_pretrained(self.args.model_path) - for layer in model.encoder.layer: - layer.attention.self.attention_mode = self.args.attention_mode - self.args.attention_window = layer.attention.self.attention_window + if 'longformer' in self.args.model_path: + model = Longformer.from_pretrained(self.args.model_path) + for layer in model.encoder.layer: + layer.attention.self.attention_mode = self.args.attention_mode + self.args.attention_window = layer.attention.self.attention_window + elif self.args.model_path in ['bart.large', 'bart.base']: + model = torch.hub.load('pytorch/fairseq', self.args.model_path) + model.config = model.args + model.config.hidden_size = model.config.decoder_output_dim + elif 'bart' in self.args.model_path and 'base' in self.args.model_path: + config = AutoConfig.from_pretrained(self.args.model_path) + config.encoder_attention_heads = 12 + config.decoder_attention_heads = 12 + config.attention_dropout = 0.1 + if self.args.seq2seq: + model = AutoModelWithLMHead.from_pretrained(self.args.model_path, config=config) + else: + model = AutoModel.from_pretrained(self.args.model_path, config=config) + elif 'bart' in self.args.model_path and 'large' in self.args.model_path: + config = AutoConfig.from_pretrained(self.args.model_path) + config.attention_dropout = 0.1 + config.gradient_checkpointing = True + if self.args.seq2seq: + model = AutoModelWithLMHead.from_pretrained(self.args.model_path, config=config) + else: + model = AutoModel.from_pretrained(self.args.model_path, config=config) + else: + model = AutoModel.from_pretrained(self.args.model_path) print("Loaded model with config:") print(model.config) @@ 
-276,30 +332,51 @@ def load_model(self): model.train() return model - def forward(self, input_ids, attention_mask, segment_ids, start_positions, end_positions): - question_end_index = self._get_question_end_index(input_ids) - # Each batch is one document, and each row of the batch is a chunck of the document. - # Make sure all rows have the same question length. - assert (question_end_index[0].float() == question_end_index.float().mean()).item() - - # local attention everywhere - attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) - # global attention for the question tokens - attention_mask[:, :question_end_index.item()] = 2 - - # sliding_chunks implemenation of selfattention requires that seqlen is multiple of window size - input_ids, attention_mask = pad_to_window_size( - input_ids, attention_mask, self.args.attention_window, self.tokenizer.pad_token_id) - - sequence_output = self.model( - input_ids, - attention_mask=attention_mask)[0] - - # The pretrained TriviaQA model wasn't trained with padding, so remove padding tokens - # before computing loss and decoding. - padding_len = input_ids[0].eq(self.tokenizer.pad_token_id).sum() - if padding_len > 0: - sequence_output = sequence_output[:, :-padding_len] + def forward(self, input_ids, attention_mask, segment_ids, start_positions, end_positions, answer_token_ids): + if 'longformer' in self.args.model_path: + question_end_index = self._get_question_end_index(input_ids) + # Each batch is one document, and each row of the batch is a chunck of the document. + # Make sure all rows have the same question length. + assert (question_end_index[0].float() == question_end_index.float().mean()).item() + + # local attention everywhere + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) + # global attention for the question tokens + attention_mask[:, :question_end_index.item()] = 2 + + # sliding_chunks implemenation of selfattention requires that seqlen is multiple of window size + input_ids, attention_mask = pad_to_window_size( + input_ids, attention_mask, self.args.attention_window, self.tokenizer.pad_token_id) + + sequence_output = self.model( + input_ids, + attention_mask=attention_mask)[0] + + # The pretrained TriviaQA model wasn't trained with padding, so remove padding tokens + # before computing loss and decoding. 
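+            # (this padding was only added above by pad_to_window_size to satisfy the window-size constraint)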
+ padding_len = input_ids[0].eq(self.tokenizer.pad_token_id).sum() + if padding_len > 0: + sequence_output = sequence_output[:, :-padding_len] + elif self.args.model_path in ['bart.large', 'bart.base']: + sequence_output = self.model.extract_features(input_ids) + else: + if self.args.seq2seq: + decoder_input_ids = answer_token_ids[:, 0, :-1].clone() + decoder_input_ids[decoder_input_ids == self.tokenizer.eos_token_id] = self.tokenizer.pad_token_id + decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id) + labels = answer_token_ids[:, 0, 1:].contiguous() + labels[answer_token_ids[:, 0, 1:] == self.tokenizer.pad_token_id] = -100 + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=labels) + loss = outputs[0] + logit_scores = outputs[1].softmax(dim=2)[:, :, 0].sum(dim=1) + return [loss, logit_scores] + else: + sequence_output = self.model(input_ids, attention_mask=attention_mask)[0] logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -368,8 +445,8 @@ def or_softmax_cross_entropy_loss_one_doc(self, logits, target, ignore_index=-1, return loss[~torch.isinf(loss)].sum() def training_step(self, batch, batch_nb): - input_ids, input_mask, segment_ids, subword_starts, subword_ends, qids, aliases = batch - output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends) + input_ids, input_mask, segment_ids, subword_starts, subword_ends, answer_token_ids, qids, aliases = batch + output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends, answer_token_ids) loss = output[0] lr = loss.new_zeros(1) + self.trainer.optimizers[0].param_groups[0]['lr'] tensorboard_logs = {'train_loss': loss, 'lr': lr, @@ -378,8 +455,29 @@ def training_step(self, batch, batch_nb): return {'loss': loss, 'log': tensorboard_logs} def validation_step(self, batch, batch_nb): - input_ids, input_mask, segment_ids, subword_starts, subword_ends, qids, aliases = batch - output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends) + input_ids, input_mask, segment_ids, subword_starts, subword_ends, answer_token_ids, qids, aliases = batch + output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends, answer_token_ids) + if self.args.seq2seq: + logit_scores = output[1] + answer_score_indices = logit_scores.sort().indices + generated_ids = self.model.generate(input_ids=input_ids, attention_mask=input_mask, use_cache=True,) + answer_text = '' + best_answer_score = 0 + for i in answer_score_indices: + generated_answer_ids = generated_ids[answer_score_indices[i]] + generated_answer_ids[-1] = self.tokenizer.eos_token_id + index_of_eos_token = (generated_answer_ids == self.tokenizer.eos_token_id).nonzero()[0, 0].item() + generated_answer_ids = generated_answer_ids[1:index_of_eos_token] + answer_text = self.tokenizer.decode(generated_answer_ids) + if answer_text != '': + best_answer_score = logit_scores[answer_score_indices[i]] + break + f1_score = evaluation_utils.metric_max_over_ground_truths(evaluation_utils.f1_score, answer_text, aliases) + em_score = evaluation_utils.metric_max_over_ground_truths(evaluation_utils.exact_match_score, answer_text, aliases) + return {'vloss': output[0], 'vem': generated_answer_ids.new_zeros([1]).float(), + 'qids': [qids], 'answer_scores': [best_answer_score], + 'f1': [f1_score], 'em': [em_score]} + loss, start_logits, end_logits = 
output[:3] answers = self.decode(input_ids, start_logits, end_logits) @@ -453,8 +551,8 @@ def decode(self, input_ids, start_logits, end_logits): answers.append({'text': text, 'score': score}) return answers - def sync_list_across_gpus(self, l, device, dtype): - l_tensor = torch.tensor(l, device=device, dtype=dtype) + def sync_list_across_gpus(self, list_to_sync, device, dtype): + l_tensor = torch.tensor(list_to_sync, device=device, dtype=dtype) gather_l_tensor = [torch.ones_like(l_tensor) for _ in range(self.trainer.world_size)] torch.distributed.all_gather(gather_l_tensor, l_tensor) return torch.cat(gather_l_tensor).tolist() @@ -499,8 +597,11 @@ def validation_end(self, outputs): return {'avg_val_loss': avg_loss, 'log': logs, 'progress_bar': logs} def test_step(self, batch, batch_nb): - input_ids, input_mask, segment_ids, subword_starts, subword_ends, qids, aliases = batch - output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends) + input_ids, input_mask, segment_ids, subword_starts, subword_ends, answer_token_ids, qids, aliases = batch + output = self.forward(input_ids, input_mask, segment_ids, subword_starts, subword_ends, answer_token_ids) + if self.args.seq2seq: + raise NotImplemented + loss, start_logits, end_logits = output[:3] answers = self.decode(input_ids, start_logits, end_logits) @@ -528,21 +629,14 @@ def test_end(self, outputs): return {'count': len(qid_to_answer_text)} - def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None): - optimizer.step() - optimizer.zero_grad() - self.scheduler.step(self.global_step) - def configure_optimizers(self): def lr_lambda(current_step): if current_step < self.args.warmup: return float(current_step) / float(max(1, self.args.warmup)) return max(0.0, float(self.args.steps - current_step) / float(max(1, self.args.steps - self.args.warmup))) optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr) - self.scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1) # scheduler is not saved in the checkpoint, but global_step is, which is enough to restart - self.scheduler.step(self.global_step) - - return optimizer + scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1) + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] @pl.data_loader def train_dataloader(self): @@ -554,7 +648,7 @@ def train_dataloader(self): max_num_answers=self.args.max_num_answers, max_question_len=self.args.max_question_len, ignore_seq_with_no_answers=self.args.ignore_seq_with_no_answers) - sampler = torch.utils.data.distributed.DistributedSampler(dataset) if self.trainer.use_ddp else None + sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True) if self.trainer.use_ddp else None dl = DataLoader(dataset, batch_size=1, shuffle=(sampler is None), num_workers=self.args.num_workers, sampler=sampler, collate_fn=TriviaQADataset.collate_one_doc_and_lists) @@ -571,8 +665,8 @@ def val_dataloader(self): max_num_answers=self.args.max_num_answers, max_question_len=self.args.max_question_len, ignore_seq_with_no_answers=False) # evaluation data should keep all examples - sampler = torch.utils.data.distributed.DistributedSampler(dataset) if self.trainer.use_ddp else None - dl = DataLoader(dataset, batch_size=1, shuffle=(sampler is None), + sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False) if self.trainer.use_ddp else None + dl = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.args.num_workers, sampler=sampler, 
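+                        # batch_size=1: each batch is a single document and each row is one chunk of it (see forward())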
collate_fn=TriviaQADataset.collate_one_doc_and_lists) self.val_dataloader_object = dl @@ -599,7 +693,7 @@ def configure_ddp(self, model, device_ids): model = LightningDistributedDataParallel( model, device_ids=device_ids, - find_unused_parameters=True + find_unused_parameters=False ) return model @@ -610,11 +704,11 @@ def add_model_specific_args(parser, root_dir): parser.add_argument("--train_dataset", type=str, required=False, help="Path to the training squad-format") parser.add_argument("--dev_dataset", type=str, required=True, help="Path to the dev squad-format") parser.add_argument("--batch_size", type=int, default=8, help="Batch size") - parser.add_argument("--gpus", type=str, default='0', - help="Comma separated list of gpus. Default is gpu 0. To use CPU, use --gpus "" ") + parser.add_argument("--gpus", type=int, default=1, + help="Number of gpus. 0 for CPU") parser.add_argument("--warmup", type=int, default=200, help="Number of warmup steps") parser.add_argument("--lr", type=float, default=0.0001, help="Maximum learning rate") - parser.add_argument("--val_every", type=float, default=0.2, help="Number of training steps between validations") + parser.add_argument("--val_every", type=float, default=0.5, help="Number of training steps between validations") parser.add_argument("--val_percent_check", default=1.00, type=float, help='Percent of validation data used') parser.add_argument("--num_workers", type=int, default=4, help="Number of data loader workers") parser.add_argument("--seed", type=int, default=1234, help="Seed") @@ -636,7 +730,8 @@ def add_model_specific_args(parser, root_dir): help="Number of answer candidates. Used at decoding time") parser.add_argument("--max_answer_length", type=int, default=30, help="maximum num of wordpieces/answer. Used at decoding time") - parser.add_argument("--regular_softmax_loss", action='store_true', help="IF true, use regular softmax. Default is using ORed softmax loss") + parser.add_argument("--regular_softmax_loss", action='store_true', + help="IF true, use regular softmax. Default is using ORed softmax loss") parser.add_argument("--test", action='store_true', help="Test only, no training") parser.add_argument("--model_path", type=str, required=True, help="Path to the checkpoint directory") @@ -644,6 +739,9 @@ def add_model_specific_args(parser, root_dir): parser.add_argument("--attention_mode", type=str, choices=['tvm', 'sliding_chunks'], default='sliding_chunks', help='Which implementation of selfattention to use') parser.add_argument("--fp32", action='store_true', help="default is fp16. Use --fp32 to switch to fp32") + parser.add_argument("--seq2seq", action='store_true', help="Use an answer generation model") + parser.add_argument("--resume_ckpt", type=str, help="Path of a checkpoint to resume from") + return parser @@ -667,28 +765,32 @@ def main(args): filepath=os.path.join(args.save_dir, args.save_prefix, "checkpoints"), save_top_k=5, verbose=True, - monitor='avg_val_f1', - mode='max', + monitor='avg_val_loss', + # save_last=True, + mode='min', + period=-1, prefix='' ) - args.gpus = [int(x) for x in args.gpus.split(',')] if args.gpus is not "" else None # use CPU if no gpu provided print(args) train_set_size = 110648 # hardcode dataset size. 
Needed to compute number of steps for the lr scheduler - num_devices = 1 or len(args.gpus) - args.steps = args.epochs * train_set_size / (args.batch_size * num_devices) - print(f'>>>>>>> #steps: {args.steps}, #epochs: {args.epochs}, batch_size: {args.batch_size * num_devices} <<<<<<<') + args.steps = args.epochs * train_set_size / (args.batch_size * max(args.gpus, 1)) + print(f'>>>>>>> #steps: {args.steps}, #epochs: {args.epochs}, batch_size: {args.batch_size * args.gpus} <<<<<<<') - trainer = pl.Trainer(gpus=args.gpus, distributed_backend='ddp' if args.gpus and (len(args.gpus) > 1) else None, - track_grad_norm=-1, max_nb_epochs=args.epochs, early_stop_callback=None, + trainer = pl.Trainer(gpus=args.gpus, distributed_backend='ddp' if args.gpus and args.gpus > 1 else None, + track_grad_norm=-1, max_epochs=args.epochs, early_stop_callback=None, + replace_sampler_ddp=False, accumulate_grad_batches=args.batch_size, val_check_interval=args.val_every, + num_sanity_val_steps=2, + # check_val_every_n_epoch=2, val_percent_check=args.val_percent_check, test_percent_check=args.val_percent_check, logger=logger if not args.disable_checkpointing else False, checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else False, show_progress_bar=not args.no_progress_bar, use_amp=not args.fp32, amp_level='O2', + resume_from_checkpoint=args.resume_ckpt, ) if not args.test: trainer.fit(model)
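+    # example invocation (illustrative only; paths and flag values are placeholders):
+    #   python -m scripts.triviaqa --model_path <longformer checkpoint dir> \
+    #       --train_dataset <train.squad.json> --dev_dataset <dev.squad.json> \
+    #       --attention_mode sliding_chunks --gpus 1 --batch_size 8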