Conversation
class LongformerSelfAttention(nn.Module):
-    def __init__(self, config, layer_id):
+    def __init__(self, config, layer_id, bias=True, attention_dim_scale=True):
The T5 attention module is slightly different from conventional ones: it has no bias terms, nor does it scale the attention scores by the attention head dimension before softmax. See this list for more details.
With the default options (bias=True, attention_dim_scale=True), this simply falls back to regular self-attention.
Please add your comment to the code.
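As a sketch of what the two new flags control (attn_score is a hypothetical scalar helper for illustration, not code from this PR):

```python
import math

def attn_score(q, k, bias_term=0.0, attention_dim_scale=True):
    # Hypothetical helper: attention score for one query/key pair.
    # Conventional attention scales by 1/sqrt(head_dim) before softmax;
    # T5-style attention skips both the bias and the scaling.
    score = sum(qi * ki for qi, ki in zip(q, k)) + bias_term
    if attention_dim_scale:
        score /= math.sqrt(len(q))  # len(q) stands in for head_dim
    return score

q = k = [1.0, 0.0, 1.0, 0.0]
conventional = attn_score(q, k)                         # 2 / sqrt(4) = 1.0
t5_style = attn_score(q, k, attention_dim_scale=False)  # 2.0, no scaling
```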
selected_attn_weights[selection_padding_mask_zeros[0], :, :, selection_padding_mask_zeros[1]] = -10000
# concat to attn_weights
-# (bsz, seq_len, num_heads, extra attention count + 2*window+1)
+# (bsz, seq_len, num_heads, max_num_extra_indices_per_batch + 2*window+1)
changed annotation to be consistent with related annotations below
self.attention_mode = config.attention_mode
self.autoregressive = config.autoregressive

if hasattr(config, "relative_attention_num_buckets") and layer_id == 0:
In T5, the position bias is shared across layers: the first layer computes it, then passes it on to the remaining layers.
Good catch. Please write this comment in the code for more readability.
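The sharing pattern can be sketched with toy classes (names here are illustrative; a string stands in for the real compute_bias call):

```python
class ToyAttentionLayer:
    # Only layer 0 owns the relative-attention-bias parameters; the bias it
    # computes is threaded through the layer loop so later layers reuse it.
    def __init__(self, layer_id):
        self.has_relative_attention_bias = (layer_id == 0)

    def forward(self, hidden, position_bias=None):
        if position_bias is None and self.has_relative_attention_bias:
            position_bias = "bias-from-layer-0"  # stand-in for compute_bias()
        return hidden, position_bias

layers = [ToyAttentionLayer(i) for i in range(4)]
hidden, position_bias = "x", None
for layer in layers:
    hidden, position_bias = layer.forward(hidden, position_bias)
```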
if output_attentions:
    outputs = outputs + (attn_weights,)
    if self.has_relative_attention_bias:
        outputs = outputs + (position_bias,)
This is equivalent to the old output form when self.has_relative_attention_bias=False.
return outputs
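A tiny sketch to confirm the equivalence (build_outputs is hypothetical and starts the tuple from attn_output only, abstracting away whatever else the real method packs in):

```python
def build_outputs(attn_output, attn_weights, position_bias,
                  output_attentions, has_relative_attention_bias):
    # mirrors the branch quoted above
    outputs = (attn_output,)
    if output_attentions:
        outputs = outputs + (attn_weights,)
        if has_relative_attention_bias:
            outputs = outputs + (position_bias,)
    return outputs

old_form = ("out", "weights")  # pre-PR result with output_attentions=True
new_form = build_outputs("out", "weights", "bias",
                         output_attentions=True, has_relative_attention_bias=False)
```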
def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
I was considering moving this to longformer_encoder_decoder, but that would lead to a cyclic import, so it has to stay here.
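For readers unfamiliar with the function, the T5 bucketing scheme can be sketched for a single relative position in pure Python (a scalar sketch under the mesh-tf sign convention, not the vectorized code in this PR):

```python
import math

def bucket_one(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    ret = 0
    n = -relative_position
    if bidirectional:
        num_buckets //= 2          # half the buckets for each direction
        if n < 0:
            ret += num_buckets
        n = abs(n)
    else:
        n = max(n, 0)
    max_exact = num_buckets // 2   # small offsets get one bucket each
    if n < max_exact:
        return ret + n
    # larger offsets share logarithmically sized bins up to max_distance
    val = max_exact + int(
        math.log(n / max_exact) / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    )
    return ret + min(val, num_buckets - 1)
```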
layer.layer[0].SelfAttention = LongformerSelfAttentionForT5(config, layer_id=i)
class LongformerT5Config(T5Config):
As you can see, we are getting many highly similar config classes as we extend to other transformer models. If you like, we can simplify this using a mixin: another mixin class would contain all the Longformer-specific settings, and LongformerT5Config would inherit from both the mixin class and T5Config.
I don't have strong feelings about this. You decide (as long as we don't change the interface of the released code)
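Sketched, the mixin idea could look like this (T5ConfigStub stands in for transformers' T5Config; attribute names follow the Longformer settings quoted in this PR):

```python
class T5ConfigStub:
    # stand-in for transformers' T5Config
    def __init__(self, **kwargs):
        self.dropout_rate = kwargs.get("dropout_rate", 0.1)

class LongformerConfigMixin:
    # all Longformer-specific settings live in one place
    def __init__(self, attention_window=512, attention_mode="sliding_chunks",
                 autoregressive=False, **kwargs):
        super().__init__(**kwargs)  # cooperative: forwards to the base config
        self.attention_window = attention_window
        self.attention_mode = attention_mode
        self.autoregressive = autoregressive

class LongformerT5ConfigSketch(LongformerConfigMixin, T5ConfigStub):
    pass

cfg = LongformerT5ConfigSketch(attention_window=1024)
```

The same mixin could then be reused for a Longformer-BART or Longformer-RoBERTa config without duplicating the fields.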
)
self.output = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
def forward(
An alternative I considered was to have this class inherit from LongformerSelfAttention, but eventually I decided not to. The interfaces of the two classes are quite different, so what we have here, i.e., making LongformerSelfAttention a member of LongformerSelfAttentionForT5, is probably less confusing than the alternative.
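The composition choice can be sketched with toy classes (these are stand-ins, not the PR's actual interfaces):

```python
class SlidingWindowAttention:
    # stand-in for LongformerSelfAttention's interface
    def forward(self, hidden_states, attention_mask=None):
        return hidden_states

class T5StyleWrapper:
    # Composition: the T5-facing class owns a sliding-window attention member
    # and adapts between the two interfaces (e.g. threading position_bias
    # in and out), rather than inheriting and overriding methods.
    def __init__(self):
        self.longformer_self_attn = SlidingWindowAttention()

    def forward(self, hidden_states, mask=None, position_bias=None):
        attn_output = self.longformer_self_attn.forward(hidden_states, attention_mask=mask)
        return (attn_output, position_bias)

out, bias = T5StyleWrapper().forward("h", position_bias="b")
```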
# (bsz, seq_len, num_heads, max_num_extra_indices_per_batch + 2*window+1)
attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
if position_bias is None and self.has_relative_attention_bias:
Since the sliding window has already put the attention scores in the form [q_(i) * k_(i-w), q_(i) * k_(i-w+1), ..., q_(i) * k_(i), ..., q_(i) * k_(i+w)], the relative position is simply an arange.
please move this comment to the code.
nit: Maybe also move this block of code to a separate function
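Concretely, the window-relative positions are just an arange (pure-Python sketch):

```python
def window_relative_positions(window):
    # After the sliding-window rearrangement, query i's scores are laid out
    # against keys k_(i-w), ..., k_(i), ..., k_(i+w), so the relative
    # position along that axis is simply an arange from -w to w.
    return list(range(-window, window + 1))
```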
perm_global_position_bias = attn_weights.new_zeros(
    bsz, max_num_extra_indices_per_batch, seq_len, self.num_heads
)  # (bsz, max_num_extra_indices_per_batch, seq_len, num_heads)
if extra_attention_mask is not None:
The global position bias is a bit more complex. We first get the memory positions from extra_attention_mask_nonzeros, then compute the query positions using arange; their difference is the relative position. But this is "sparse": one vector for each global token in the batch. So we later put it back into the shape (bsz, max_num_extra_indices_per_batch, ...) using the index information from selection_padding_mask_nonzeros.
didn't review this part yet.
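The core of the computation described above, minus the padding bookkeeping, can be sketched in pure Python (variable names follow the comment; the sign convention is an assumption):

```python
def global_relative_positions(memory_positions, seq_len):
    # memory_positions: absolute positions of the global tokens, i.e. what
    # extra_attention_mask_nonzeros yields for one batch element.
    # Query positions are range(seq_len); the difference memory - query is
    # the relative position used to look up the bias bucket.
    return [[mem - q for q in range(seq_len)] for mem in memory_positions]

rel = global_relative_positions([0, 5], seq_len=8)
```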
ibeltagy left a comment
Looks great, thank you.
I left a few small comments. I didn't review the global attention part yet, will do later, maybe today.
    base_model_name_or_path="t5-small",
)
self._run_test(
    INPUT_TEXT="It begins with the Great Hungerer. It ends in utter darkeness.",
def test_outout(self):
    self._run_test(
        INPUT_TEXT="Hello world!",
        long_model_name_or_path="/net/nfs2.s2-research/haokunl/exp_files/model_artifacts/t5/longt5-small-4096",
It would be great if this test worked without the local model. One way to do so is to call create_long_model in the test to convert T5 to long, then test it. It will make the test slower but easier to run.
from longformer.diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations
from longformer.sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv
from longformer.sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv
from longformer.sliding_chunks import (
It is fine that your dev env changed the file formatting. I know it doesn't change the code, but I will feel more comfortable if you run a small test to make sure the new code produces the same output as the previous one for Longformer.
# in T5 attention_probs_dropout_prob is dropout_rate
config.attention_probs_dropout_prob = config.dropout_rate
config.attention_window = [attention_window] * config.num_hidden_layers
config.attention_dilation = [1] * config.num_hidden_layers
When increasing the model length, we probably want to increase the number of relative position buckets as well (config.relative_attention_num_buckets).
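Folding that suggestion into the conversion might look like this (sketched on a SimpleNamespace config; the proportional scale factor is an assumption, not something the PR currently does):

```python
from types import SimpleNamespace

def to_long_config(config, attention_window, position_scale_factor):
    # mirrors the conversion quoted above, plus scaling the number of
    # relative position buckets along with the maximum length (assumption)
    config.attention_probs_dropout_prob = config.dropout_rate
    config.attention_window = [attention_window] * config.num_hidden_layers
    config.attention_dilation = [1] * config.num_hidden_layers
    config.relative_attention_num_buckets *= position_scale_factor
    return config

cfg = to_long_config(
    SimpleNamespace(dropout_rate=0.1, num_hidden_layers=6, relative_attention_num_buckets=32),
    attention_window=512,
    position_scale_factor=8,  # e.g. 512 -> 4096 tokens
)
```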
Based on @AkshitaB's work (#149), this PR extends Longformer to T5. It also adds a test to check that the Longformer T5 produces the same output as the standard T5 on short input texts, as suggested by @ibeltagy in this comment.
A quick note about code style: I'm not sure if this repo has selected a formatter previously, and I didn't find a dev-requirements.txt. So I continued to use the black formatter with my default settings; it automatically re-formats the file whenever I save it. You may notice changes like ' -> ", or a long line broken into multiple lines. I hope it doesn't bother you too much.