6 changes: 3 additions & 3 deletions docs/source/en/cache_explanation.md
@@ -41,13 +41,13 @@
$$

The query (`Q`), key (`K`), and value (`V`) matrices are projections from the input embeddings of shape `(b, h, T, d_head)`.

-For causal attention, the mask prevents the model from attending to future tokens. Once a token is processed, its representation never changes with respect to future tokens, which means \\( K_{\text{past}} \\) and \\( V_{\text{past}} \\) can be cached and reused to compute the last token's representation.
+For causal attention, the mask prevents the model from attending to future tokens. Once a token is processed, its representation never changes with respect to future tokens, which means $ K_{\text{past}} $ and $ V_{\text{past}} $ can be cached and reused to compute the last token's representation.

$$
\text{Attention}(q_t, [\underbrace{k_1, k_2, \dots, k_{t-1}}_{\text{cached}}, k_{t}], [\underbrace{v_1, v_2, \dots, v_{t-1}}_{\text{cached}}, v_{t}])
$$

-At inference time, you only need the last token's query to compute the representation \\( x_t \\) that predicts the next token \\( t+1 \\). At each step, the new key and value vectors are **stored** in the cache and **appended** to the past keys and values.
+At inference time, you only need the last token's query to compute the representation $ x_t $ that predicts the next token $ t+1 $. At each step, the new key and value vectors are **stored** in the cache and **appended** to the past keys and values.

$$
K_{\text{cache}} \leftarrow \text{concat}(K_{\text{past}}, k_t), \quad V_{\text{cache}} \leftarrow \text{concat}(V_{\text{past}}, v_t)
@@ -59,7 +59,7 @@
Refer to the table below to compare how caching improves efficiency.

| without caching | with caching |
|---|---|
-| for each step, recompute all previous `K` and `V` | for each step, only compute current `K` and `V`
+| for each step, recompute all previous `K` and `V` | for each step, only compute current `K` and `V` |
| attention cost per step is **quadratic** with sequence length | attention cost per step is **linear** with sequence length (memory grows linearly, but compute/token remains low) |

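To make the update rule concrete, here is a minimal sketch of the per-step cache update, assuming the `(b, h, T, d_head)` layout described above (the function name is illustrative, not the actual cache API):

```python
import torch

def update_kv_cache(k_past, v_past, k_t, v_t):
    # k_past/v_past: (b, h, T_past, d_head) cached keys/values
    # k_t/v_t: (b, h, 1, d_head) key/value projections of the current token
    k_cache = torch.cat([k_past, k_t], dim=-2)  # K_cache <- concat(K_past, k_t)
    v_cache = torch.cat([v_past, v_t], dim=-2)  # V_cache <- concat(V_past, v_t)
    return k_cache, v_cache
```

Attention at step $t$ is then computed between the single query $q_t$ and the full `k_cache`/`v_cache`, so only the current token's key and value projections need to be computed.
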
## Cache class
40 changes: 20 additions & 20 deletions docs/source/en/model_doc/reformer.md
@@ -50,14 +50,14 @@
found [here](https://github.com/google/trax/tree/master/trax/models/reformer).

Axial Positional Encodings were first implemented in Google's [trax library](https://github.com/google/trax/blob/4d99ad4965bab1deba227539758d59f0df0fef48/trax/layers/research/position_encodings.py#L29)
and developed by the authors of this model's paper. In models that treat very long input sequences, the
-conventional position id encodings store an embedding vector of size \\(d\\) being the `config.hidden_size` for
-every position \\(i = 1, \ldots, n_s\\), with \\(n_s\\) being `config.max_embedding_size`. This means that having
-a sequence length of \\(n_s = 2^{19} \approx 0.5M\\) and a `config.hidden_size` of \\(d = 2^{10} \approx 1000\\)
+conventional position id encodings store an embedding vector of size $d$ being the `config.hidden_size` for
+every position $i = 1, \ldots, n_s$, with $n_s$ being `config.max_embedding_size`. This means that having
+a sequence length of $n_s = 2^{19} \approx 0.5M$ and a `config.hidden_size` of $d = 2^{10} \approx 1000$
would result in a position encoding matrix:

$$X_{i,j}, \text{ with } i \in \left[1,\ldots, d\right] \text{ and } j \in \left[1,\ldots, n_s\right]$$

-which alone has over 500M parameters to store. Axial positional encodings factorize \\(X_{i,j}\\) into two matrices:
+which alone has over 500M parameters to store. Axial positional encodings factorize $X_{i,j}$ into two matrices:

$$X^{1}_{i,j}, \text{ with } i \in \left[1,\ldots, d^1\right] \text{ and } j \in \left[1,\ldots, n_s^1\right]$$

@@ -76,16 +76,16 @@
X^{1}_{i, k}, & \text{if }\ i < d^1 \text{ with } k = j \mod n_s^1 \\
X^{2}_{i - d^1, l}, & \text{if } i \ge d^1 \text{ with } l = \lfloor\frac{j}{n_s^1}\rfloor
\end{cases}$$

-Intuitively, this means that a position embedding vector \\(x_j \in \mathbb{R}^{d}\\) is now the composition of two
-factorized embedding vectors: \\(x^1_{k, l} + x^2_{l, k}\\), where the `config.max_embedding_size` dimension
-\\(j\\) is factorized into \\(k \text{ and } l\\). This design ensures that each position embedding vector
-\\(x_j\\) is unique.
+Intuitively, this means that a position embedding vector $x_j \in \mathbb{R}^{d}$ is now the composition of two
+factorized embedding vectors: $x^1_{k, l} + x^2_{l, k}$, where the `config.max_embedding_size` dimension
+$j$ is factorized into $k \text{ and } l$. This design ensures that each position embedding vector
+$x_j$ is unique.

-Using the above example again, axial position encoding with \\(d^1 = 2^9, d^2 = 2^9, n_s^1 = 2^9, n_s^2 = 2^{10}\\)
-can drastically reduce the number of parameters from 500 000 000 to \\(2^{18} + 2^{19} \approx 780 000\\) parameters, which means more than 99.8% less memory usage.
+Using the above example again, axial position encoding with $d^1 = 2^9, d^2 = 2^9, n_s^1 = 2^9, n_s^2 = 2^{10}$
+can drastically reduce the number of parameters from 500 000 000 to $2^{18} + 2^{19} \approx 780 000$ parameters, which means more than 99.8% less memory usage.

-In practice, the parameter `config.axial_pos_embds_dim` is set to a tuple \\((d^1, d^2)\\) whose sum has to be
-equal to `config.hidden_size` and `config.axial_pos_shape` is set to a tuple \\((n_s^1, n_s^2)\\) whose
+In practice, the parameter `config.axial_pos_embds_dim` is set to a tuple $(d^1, d^2)$ whose sum has to be
+equal to `config.hidden_size` and `config.axial_pos_shape` is set to a tuple $(n_s^1, n_s^2)$ whose
product has to be equal to `config.max_embedding_size`, which during training has to be equal to the *sequence
length* of the `input_ids`.

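To illustrate the factorization, here is a minimal sketch with the example numbers from above (`axial_position_embedding` is a hypothetical helper for illustration, not the model's actual implementation):

```python
import torch

d1, d2 = 2**9, 2**9   # axial_pos_embds_dim: d1 + d2 == hidden_size (2^10)
n1, n2 = 2**9, 2**10  # axial_pos_shape: n1 * n2 == max_embedding_size (2^19)

x1 = torch.randn(n1, d1)  # 2^18 parameters
x2 = torch.randn(n2, d2)  # 2^19 parameters

def axial_position_embedding(j):
    # Position j is factorized into (k, l): k = j mod n1 picks a row of x1,
    # l = j // n1 picks a row of x2; concatenating the two rows gives a
    # unique d-dimensional vector for every position j < n1 * n2.
    k, l = j % n1, j // n1
    return torch.cat([x1[k], x2[l]])  # shape (d1 + d2,) == (hidden_size,)
```
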
@@ -107,19 +107,19 @@
neighboring chunks and `config.lsh_num_chunks_after` following neighboring chunks.

For more information, see the [original Paper](https://huggingface.co/papers/2001.04451) or this great [blog post](https://www.pragmatic.ml/reformer-deep-dive/).

-Note that `config.num_buckets` can also be factorized into a list \\((n_{\text{buckets}}^1,
-n_{\text{buckets}}^2)\\). This way, instead of assigning the query key embedding vectors to one of \\((1,\ldots,
-n_{\text{buckets}})\\), they are assigned to one of \\((1-1,\ldots, n_{\text{buckets}}^1-1, \ldots,
-1-n_{\text{buckets}}^2, \ldots, n_{\text{buckets}}^1-n_{\text{buckets}}^2)\\). This is crucial for very long sequences to
+Note that `config.num_buckets` can also be factorized into a list $(n_{\text{buckets}}^1,
+n_{\text{buckets}}^2)$. This way, instead of assigning the query key embedding vectors to one of $(1,\ldots,
+n_{\text{buckets}})$, they are assigned to one of $(1-1,\ldots, n_{\text{buckets}}^1-1, \ldots,
+1-n_{\text{buckets}}^2, \ldots, n_{\text{buckets}}^1-n_{\text{buckets}}^2)$. This is crucial for very long sequences to
save memory.

When training a model from scratch, it is recommended to leave `config.num_buckets=None`, so that depending on the
sequence length a good value for `num_buckets` is calculated on the fly. This value will then automatically be
saved in the config and should be reused for inference.

Using LSH self attention, the memory and time complexity of the query-key matmul operation can be reduced from
-\\(\mathcal{O}(n_s \times n_s)\\) to \\(\mathcal{O}(n_s \times \log(n_s))\\), which usually represents the memory
-and time bottleneck in a transformer model, with \\(n_s\\) being the sequence length.
+$\mathcal{O}(n_s \times n_s)$ to $\mathcal{O}(n_s \times \log(n_s))$, which usually represents the memory
+and time bottleneck in a transformer model, with $n_s$ being the sequence length.

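For intuition, the bucketing step can be sketched as follows — a minimal illustration of angular LSH via random projections, assuming shared query-key vectors as in Reformer; this is not the library's actual implementation:

```python
import torch

def lsh_bucket_ids(qk, num_buckets):
    # qk: (batch, seq_len, d_head) shared query-key vectors (Reformer ties Q and K).
    # Project onto num_buckets // 2 random directions; the argmax over the
    # projections and their negations assigns each vector an angular bucket,
    # so nearby vectors land in the same bucket with high probability.
    d_head = qk.shape[-1]
    random_proj = torch.randn(d_head, num_buckets // 2)
    rotated = qk @ random_proj  # (batch, seq_len, num_buckets // 2)
    return torch.cat([rotated, -rotated], dim=-1).argmax(dim=-1)
```

Sorting tokens by bucket id and then chunking the sorted sequence is what lets attention run over `lsh_chunk_length`-sized blocks instead of the full sequence.
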
### Local Self Attention

@@ -129,8 +129,8 @@
the key embedding vectors in its chunk and to the key embedding vectors of `config.local_num_chunks_before`
previous neighboring chunks and `config.local_num_chunks_after` following neighboring chunks.

Using Local self attention, the memory and time complexity of the query-key matmul operation can be reduced from
-\\(\mathcal{O}(n_s \times n_s)\\) to \\(\mathcal{O}(n_s \times \log(n_s))\\), which usually represents the memory
-and time bottleneck in a transformer model, with \\(n_s\\) being the sequence length.
+$\mathcal{O}(n_s \times n_s)$ to $\mathcal{O}(n_s \times \log(n_s))$, which usually represents the memory
+and time bottleneck in a transformer model, with $n_s$ being the sequence length.

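A minimal sketch of the chunking pattern follows (illustrative only — it ignores causal masking and the wrap-around at the sequence start, which a real implementation handles):

```python
import torch

def local_attention(q, k, v, chunk_len, num_before=1, num_after=0):
    # q, k, v: (batch, seq_len, d_head); seq_len must be a multiple of chunk_len.
    b, t, d = q.shape
    n = t // chunk_len
    q, k, v = (x.view(b, n, chunk_len, d) for x in (q, k, v))

    def with_neighbors(x):
        # Concatenate each chunk with its num_before previous and num_after
        # following chunks (torch.roll wraps around at the boundaries, which
        # a real implementation would mask out).
        shifts = range(num_before, -num_after - 1, -1)
        return torch.cat([x.roll(s, dims=1) for s in shifts], dim=2)

    k_ctx, v_ctx = with_neighbors(k), with_neighbors(v)
    scores = q @ k_ctx.transpose(-1, -2) / d**0.5  # (b, n, chunk_len, ctx_len)
    return (scores.softmax(dim=-1) @ v_ctx).reshape(b, t, d)
```

Each query chunk attends to at most `(num_before + 1 + num_after) * chunk_len` keys, which is what makes the cost per chunk constant in the sequence length.
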
### Training

18 changes: 9 additions & 9 deletions docs/source/en/model_doc/rwkv.md
@@ -94,43 +94,43 @@
In a traditional auto-regressive Transformer, attention is written as

$$O = \hbox{softmax}(QK^{T} / \sqrt{d}) V$$

-where \\(Q\\), \\(K\\) and \\(V\\) are matrices of shape `seq_len x hidden_size` named query, key and value (they are actually bigger matrices with a batch dimension and an attention head dimension, but we're only interested in the last two, which is where the matrix product is taken, so for the sake of simplicity we only consider those two). The product \\(QK^{T}\\) then has shape `seq_len x seq_len` and we can take the matrix product with \\(V\\) to get the output \\(O\\) of the same shape as the others.
+where $Q$, $K$ and $V$ are matrices of shape `seq_len x hidden_size` named query, key and value (they are actually bigger matrices with a batch dimension and an attention head dimension, but we're only interested in the last two, which is where the matrix product is taken, so for the sake of simplicity we only consider those two). The product $QK^{T}$ then has shape `seq_len x seq_len` and we can take the matrix product with $V$ to get the output $O$ of the same shape as the others.

Replacing the softmax by its value gives:

$$O_{i} = \frac{\sum_{j=1}^{i} e^{Q_{i} K_{j}^{T} / \sqrt{d}} V_{j}}{\sum_{j=1}^{i} e^{Q_{i} K_{j}^{T} / \sqrt{d}}}$$

-Note that the entries in \\(QK^{T}\\) corresponding to \\(j > i\\) are masked (the sum stops at j) because the attention is not allowed to look at future tokens (only past ones).
+Note that the entries in $QK^{T}$ corresponding to $j > i$ are masked (the sum stops at j) because the attention is not allowed to look at future tokens (only past ones).

In comparison, the RWKV attention is given by

$$O_{i} = \sigma(R_{i}) \frac{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}} V_{j}}{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}}}$$

-where \\(R\\) is a new matrix called receptance by the author, \\(K\\) and \\(V\\) are still the key and value (\\(\sigma\\) here is the sigmoid function). \\(W\\) is a new vector that represents the position of the token and is given by
+where $R$ is a new matrix called receptance by the author, $K$ and $V$ are still the key and value ($\sigma$ here is the sigmoid function). $W$ is a new vector that represents the position of the token and is given by

$$W_{0} = u \hbox{ and } W_{k} = (k-1)w \hbox{ for } k \geq 1$$

-with \\(u\\) and \\(w\\) learnable parameters called in the code `time_first` and `time_decay` respectively. The numerator and denominator can both be expressed recursively. Naming them \\(N_{i}\\) and \\(D_{i}\\) we have:
+with $u$ and $w$ learnable parameters called in the code `time_first` and `time_decay` respectively. The numerator and denominator can both be expressed recursively. Naming them $N_{i}$ and $D_{i}$ we have:

$$N_{i} = e^{u + K_{i}} V_{i} + \hat{N}_{i} \hbox{ where } \hat{N}_{i} = e^{K_{i-1}} V_{i-1} + e^{w + K_{i-2}} V_{i-2} \cdots + e^{(i-2)w + K_{1}} V_{1}$$

-so \\(\hat{N}_{i}\\) (called `numerator_state` in the code) satisfies
+so $\hat{N}_{i}$ (called `numerator_state` in the code) satisfies

$$\hat{N}_{0} = 0 \hbox{ and } \hat{N}_{j+1} = e^{K_{j}} V_{j} + e^{w} \hat{N}_{j}$$

and

$$D_{i} = e^{u + K_{i}} + \hat{D}_{i} \hbox{ where } \hat{D}_{i} = e^{K_{i-1}} + e^{w + K_{i-2}} \cdots + e^{(i-2)w + K_{1}}$$

-so \\(\hat{D}_{i}\\) (called `denominator_state` in the code) satisfies
+so $\hat{D}_{i}$ (called `denominator_state` in the code) satisfies

$$\hat{D}_{0} = 0 \hbox{ and } \hat{D}_{j+1} = e^{K_{j}} + e^{w} \hat{D}_{j}$$

The actual recurrent formulas used are a tiny bit more complex, as for numerical stability we don't want to compute exponentials of big numbers. Usually the softmax is not computed as is; instead, the exponential of the maximum term is divided out of the numerator and denominator:

$$\frac{e^{x_{i}}}{\sum_{j=1}^{n} e^{x_{j}}} = \frac{e^{x_{i} - M}}{\sum_{j=1}^{n} e^{x_{j} - M}}$$

-with \\(M\\) the maximum of all \\(x_{j}\\). So here on top of saving the numerator state (\\(\hat{N}\\)) and the denominator state (\\(\hat{D}\\)) we also keep track of the maximum of all terms encountered in the exponentials. So we actually use
+with $M$ the maximum of all $x_{j}$. So here on top of saving the numerator state ($\hat{N}$) and the denominator state ($\hat{D}$) we also keep track of the maximum of all terms encountered in the exponentials. So we actually use

$$\tilde{N}_{i} = e^{-M_{i}} \hat{N}_{i} \hbox{ and } \tilde{D}_{i} = e^{-M_{i}} \hat{D}_{i}$$

@@ -142,7 +142,7 @@
and

$$\tilde{D}_{0} = 0 \hbox{ and } \tilde{D}_{j+1} = e^{K_{j} - q} + e^{w + M_{j} - q} \tilde{D}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$

-and \\(M_{j+1} = q\\). With those, we can then compute
+and $M_{j+1} = q$. With those, we can then compute

$$N_{i} = e^{u + K_{i} - q} V_{i} + e^{M_{i} - q} \tilde{N}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$

@@ -152,4 +152,4 @@
$$D_{i} = e^{u + K_{i} - q} + e^{M_{i} - q} \tilde{D}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$

which finally gives us

-$$O_{i} = \sigma(R_{i}) \frac{N_{i}}{D_{i}}$$
+$$O_{i} = \sigma(R_{i}) \frac{N_{i}}{D_{i}}$$
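
Putting the recurrence together, here is a minimal sketch of the numerically stable loop described above (illustrative pure PyTorch using the `time_decay`/`time_first` naming from the text; not the library's optimized kernel):

```python
import torch

def rwkv_attention(r, k, v, w, u):
    # r, k, v: (seq_len, d); w = time_decay, u = time_first: (d,)
    seq_len, d = k.shape
    num_state = torch.zeros(d)                   # N-tilde in the text
    den_state = torch.zeros(d)                   # D-tilde in the text
    max_state = torch.full((d,), float("-inf"))  # M, running max of exponents
    out = torch.empty(seq_len, d)
    for i in range(seq_len):
        # Output at step i: q = max(u + k_i, M_i) keeps all exponents <= 0.
        q = torch.maximum(u + k[i], max_state)
        numerator = torch.exp(u + k[i] - q) * v[i] + torch.exp(max_state - q) * num_state
        denominator = torch.exp(u + k[i] - q) + torch.exp(max_state - q) * den_state
        out[i] = torch.sigmoid(r[i]) * numerator / denominator
        # State update for step i+1: q = max(k_i, w + M_i), then M_{i+1} = q.
        q = torch.maximum(k[i], w + max_state)
        num_state = torch.exp(k[i] - q) * v[i] + torch.exp(w + max_state - q) * num_state
        den_state = torch.exp(k[i] - q) + torch.exp(w + max_state - q) * den_state
        max_state = q
    return out
```

Because only three state tensors are carried across steps, generation needs constant memory per token, which is the point of the RWKV formulation.
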
12 changes: 6 additions & 6 deletions docs/source/en/perplexity.md
@@ -23,11 +23,11 @@
that the metric applies specifically to classical language models (sometimes called autoregressive or causal language
models) and is not well defined for masked language models like BERT (see [summary of the models](model_summary)).

Perplexity is defined as the exponentiated average negative log-likelihood of a sequence. If we have a tokenized
-sequence \\(X = (x_0, x_1, \dots, x_t)\\), then the perplexity of \\(X\\) is,
+sequence $X = (x_0, x_1, \dots, x_t)$, then the perplexity of $X$ is,

-$$\text{PPL}(X) = \exp \left\{ {-\frac{1}{t}\sum_i^t \log p_\theta (x_i|x_{<i}) } \right\}$$
+$$ \text{PPL}(X) = \exp\left\{ -\frac{1}{t}\sum_i^t \log p_\theta (x_i|x_{<i}) \right\} $$

-where \\(\log p_\theta (x_i|x_{<i})\\) is the log-likelihood of the ith token conditioned on the preceding tokens \\(x_{<i}\\) according to our model. Intuitively, it can be thought of as an evaluation of the model's ability to predict uniformly among the set of specified tokens in a corpus. Importantly, this means that the tokenization procedure has a direct impact on a model's perplexity which should always be taken into consideration when comparing different models.
+where $\log p_\theta (x_i|x_{<i})$ is the log-likelihood of the ith token conditioned on the preceding tokens $x_{<i}$ according to our model. Intuitively, it can be thought of as an evaluation of the model's ability to predict uniformly among the set of specified tokens in a corpus. Importantly, this means that the tokenization procedure has a direct impact on a model's perplexity which should always be taken into consideration when comparing different models.

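As a concrete illustration — a minimal sketch using GPT-2, though any causal LM works the same way — the mean negative log-likelihood returned as the model's loss exponentiates directly into perplexity:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

encodings = tokenizer("Perplexity is the exponentiated average NLL.", return_tensors="pt")
with torch.no_grad():
    # Passing labels=input_ids makes the model return the mean negative
    # log-likelihood over the sequence as `loss` (labels are shifted internally).
    loss = model(**encodings, labels=encodings["input_ids"]).loss

ppl = torch.exp(loss)  # PPL(X) = exp of the average negative log-likelihood
```
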
This is also equivalent to the exponentiation of the cross-entropy between the data and model predictions. For more
intuition about perplexity and its relationship to Bits Per Character (BPC) and data compression, check out this
@@ -42,11 +42,11 @@
factorizing a sequence and conditioning on the entire preceding subsequence at each step.

When working with approximate models, however, we typically have a constraint on the number of tokens the model can
process. The largest version of [GPT-2](model_doc/gpt2), for example, has a fixed length of 1024 tokens, so we
-cannot calculate \\(p_\theta(x_t|x_{<t})\\) directly when \\(t\\) is greater than 1024.
+cannot calculate $p_\theta(x_t|x_{<t})$ directly when $t$ is greater than 1024.

Instead, the sequence is typically broken into subsequences equal to the model's maximum input size. If a model's max
-input size is \\(k\\), we then approximate the likelihood of a token \\(x_t\\) by conditioning only on the
-\\(k-1\\) tokens that precede it rather than the entire context. When evaluating the model's perplexity of a
+input size is $k$, we then approximate the likelihood of a token $x_t$ by conditioning only on the
+$k-1$ tokens that precede it rather than the entire context. When evaluating the model's perplexity of a
sequence, a tempting but suboptimal approach is to break the sequence into disjoint chunks and add up the decomposed
log-likelihoods of each segment independently.

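For contrast, that suboptimal disjoint-chunk strategy looks like this (`disjoint_chunk_ppl` is a hypothetical helper for illustration — tokens near the start of each chunk get little or no context, which inflates the reported perplexity):

```python
import torch

def disjoint_chunk_ppl(model, input_ids, max_length):
    # input_ids: (1, seq_len). Break the sequence into independent windows of
    # at most max_length tokens and average their negative log-likelihoods.
    nlls, n_tokens = [], 0
    for start in range(0, input_ids.size(1), max_length):
        chunk = input_ids[:, start : start + max_length]
        with torch.no_grad():
            # loss is the mean NLL over the chunk's (internally shifted) tokens
            loss = model(chunk, labels=chunk).loss
        # Weighting by chunk length is approximate: labels are shifted by one,
        # so each chunk actually scores chunk.size(1) - 1 predictions.
        nlls.append(loss * chunk.size(1))
        n_tokens += chunk.size(1)
    return torch.exp(torch.stack(nlls).sum() / n_tokens)
```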