20 | 20 |
21 | 21 | @tf.keras.utils.register_keras_serializable(package="Addons") |
22 | 22 | class MultiHeadAttention(tf.keras.layers.Layer): |
23 | | - r""" |
24 | | - MultiHead Attention layer. |
| 23 | + r"""MultiHead Attention layer. |
25 | 24 |
26 | 25 | Defines the MultiHead Attention operation as described in |
27 | 26 | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) which takes |
28 | 27 | in the tensors `query`, `key`, and `value`, and returns the dot-product attention |
29 | 28 | between them: |
30 | 29 |
31 | | - ```python |
32 | | - mha = MultiHeadAttention(head_size=128, num_heads=12) |
| 30 | + ```python |
| 31 | + mha = MultiHeadAttention(head_size=128, num_heads=12) |
33 | 32 |
34 | | - query = tf.random.uniform((32, 20, 200)) # (batch_size, query_elements, query_depth) |
35 | | - key = tf.random.uniform((32, 15, 300)) # (batch_size, key_elements, key_depth) |
36 | | - value = tf.random.uniform((32, 15, 400)) # (batch_size, key_elements, value_depth) |
| 33 | + query = tf.random.uniform((32, 20, 200)) # (batch_size, query_elements, query_depth) |
| 34 | + key = tf.random.uniform((32, 15, 300)) # (batch_size, key_elements, key_depth) |
| 35 | + value = tf.random.uniform((32, 15, 400)) # (batch_size, key_elements, value_depth) |
37 | 36 |
38 | | - attention = mha([query, key, value]) # (batch_size, query_elements, value_depth) |
39 | | - ``` |
| 37 | + attention = mha([query, key, value]) # (batch_size, query_elements, value_depth) |
| 38 | + ``` |
40 | 39 |
41 | 40 | If `value` is not given then internally `value = key` will be used: |
42 | 41 |
43 | | - ```python |
44 | | - mha = MultiHeadAttention(head_size=128, num_heads=12) |
| 42 | + ```python |
| 43 | + mha = MultiHeadAttention(head_size=128, num_heads=12) |
45 | 44 |
46 | | - query = tf.random.uniform((32, 20, 200)) # (batch_size, query_elements, query_depth) |
47 | | - key = tf.random.uniform((32, 15, 300)) # (batch_size, key_elements, key_depth) |
| 45 | + query = tf.random.uniform((32, 20, 200)) # (batch_size, query_elements, query_depth) |
| 46 | + key = tf.random.uniform((32, 15, 300)) # (batch_size, key_elements, key_depth) |
48 | 47 |
49 | | - attention = mha([query, key]) # (batch_size, query_elements, key_depth) |
50 | | - ``` |
| 48 | + attention = mha([query, key]) # (batch_size, query_elements, key_depth) |
| 49 | + ``` |
51 | 50 |
52 | 51 | Arguments: |
53 | 52 | head_size: int, dimensionality of the `query`, `key` and `value` tensors |
54 | | - after the linear transformation. |
| 53 | + after the linear transformation. |
55 | 54 | num_heads: int, number of attention heads. |
56 | 55 | output_size: int, dimensionality of the output space, if `None` then the |
57 | | - input dimension of |
58 | | - `value` or `key` will be used, default `None`. |
| 56 | + input dimension of `value` or `key` will be used, |
| 57 | + default `None`. |
59 | 58 | dropout: float, `rate` parameter for the dropout layer that is |
60 | | - applied to attention after softmax, |
| 59 | + applied to attention after softmax, |
61 | 60 | default `0`. |
62 | 61 | use_projection_bias: bool, whether to use a bias term after the linear |
63 | | - output projection. |
| 62 | + output projection. |
64 | 63 | return_attn_coef: bool, if `True`, return the attention coefficients as |
65 | | - an additional output argument. |
| 64 | + an additional output argument. |
66 | 65 | kernel_initializer: initializer, initializer for the kernel weights. |
67 | 66 | kernel_regularizer: regularizer, regularizer for the kernel weights. |
68 | 67 | kernel_constraint: constraint, constraint for the kernel weights. |
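
A minimal usage sketch tying these arguments together (assuming the layer is exposed as `tfa.layers.MultiHeadAttention`; the shapes and the attention-coefficient layout shown in the comments are illustrative, not taken from this diff):

```python
import tensorflow as tf
import tensorflow_addons as tfa  # assumed import path for the layer edited in this file

# 8 heads of size 64, with dropout on the attention weights and the
# attention coefficients returned as a second output.
mha = tfa.layers.MultiHeadAttention(
    head_size=64,
    num_heads=8,
    output_size=256,        # if None, the depth of `value` (or `key`) is used
    dropout=0.1,            # applied to the attention weights after softmax
    use_projection_bias=True,
    return_attn_coef=True,
)

query = tf.random.uniform((32, 20, 200))  # (batch_size, query_elements, query_depth)
key = tf.random.uniform((32, 15, 300))    # (batch_size, key_elements, key_depth)
value = tf.random.uniform((32, 15, 400))  # (batch_size, key_elements, value_depth)

output, attn_coef = mha([query, key, value])
# output:    (32, 20, 256)   -- (batch_size, query_elements, output_size)
# attn_coef: (32, 8, 20, 15) -- per-head attention weights (assumed layout:
#                               batch_size, num_heads, query_elements, key_elements)
```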