import math

import torch
import torch.nn as nn


class mLSTMCell(nn.Module):
    """Matrix Long Short-Term Memory (mLSTM) cell.

    Implements the mLSTM algorithm as described in the paper:
    (https://arxiv.org/pdf/2407.10240).

    Parameters
    ----------
    input_size : int
        Size of the input feature vector.
    hidden_size : int
        Number of hidden units in the LSTM cell.
    dropout : float, optional
        Dropout rate applied to inputs and hidden states, by default 0.2.
    layer_norm : bool, optional
        If True, apply Layer Normalization to gates and interactions, by default True.

    Attributes
    ----------
    Wq : nn.Linear
        Linear layer for computing the query vector.
    Wk : nn.Linear
        Linear layer for computing the key vector.
    Wv : nn.Linear
        Linear layer for computing the value vector.
    Wi : nn.Linear
        Linear layer for the input gate.
    Wf : nn.Linear
        Linear layer for the forget gate.
    Wo : nn.Linear
        Linear layer for the output gate.
    dropout : nn.Dropout
        Dropout regularization layer.
    ln_q, ln_k, ln_v, ln_i, ln_f, ln_o : nn.LayerNorm
        Optional layer normalization layers for the respective computations.
    """
| 43 | + |
| 44 | + def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True): |
| 45 | + super().__init__() |
| 46 | + self.input_size = input_size |
| 47 | + self.hidden_size = hidden_size |
| 48 | + self.layer_norm = layer_norm |
| 49 | + |
| 50 | + self.Wq = nn.Linear(input_size, hidden_size) |
| 51 | + self.Wk = nn.Linear(input_size, hidden_size) |
| 52 | + self.Wv = nn.Linear(input_size, hidden_size) |
| 53 | + |
| 54 | + self.Wi = nn.Linear(input_size, hidden_size) |
| 55 | + self.Wf = nn.Linear(input_size, hidden_size) |
| 56 | + self.Wo = nn.Linear(input_size, hidden_size) |
| 57 | + |
| 58 | + self.dropout = nn.Dropout(dropout) |
| 59 | + |
| 60 | + if layer_norm: |
| 61 | + self.ln_q = nn.LayerNorm(hidden_size) |
| 62 | + self.ln_k = nn.LayerNorm(hidden_size) |
| 63 | + self.ln_v = nn.LayerNorm(hidden_size) |
| 64 | + self.ln_i = nn.LayerNorm(hidden_size) |
| 65 | + self.ln_f = nn.LayerNorm(hidden_size) |
| 66 | + self.ln_o = nn.LayerNorm(hidden_size) |
| 67 | + |
| 68 | + self.sigmoid = nn.Sigmoid() |
| 69 | + self.tanh = nn.Tanh() |
| 70 | + |
    def forward(self, x, h_prev, c_prev, n_prev):
        """Compute the next hidden, cell, and normalized states of the mLSTM cell.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, input_size).
        h_prev : torch.Tensor
            Previous hidden state of shape (batch_size, hidden_size).
        c_prev : torch.Tensor
            Previous cell state of shape (batch_size, hidden_size).
        n_prev : torch.Tensor
            Previous normalized state of shape (batch_size, hidden_size).

        Returns
        -------
        tuple of torch.Tensor
            h : torch.Tensor
                Current hidden state.
            c : torch.Tensor
                Current cell state.
            n : torch.Tensor
                Current normalized state.
        """
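        # As implemented below, a single step computes (per batch element):
        #   q = Wq x,  k = Wk x / sqrt(hidden_size),  v = Wv x
        #   i, f, o = sigmoid(Wi x), sigmoid(Wf x), sigmoid(Wo x)
        #   c = f * c_prev + i * sum_over_rows(k v^T)   (vector-valued cell state)
        #   n = f * n_prev + i * k
        #   h = o * tanh(c * n / (||n|| + eps))
        # with optional LayerNorm on the q/k/v and gate pre-activations.
        # (q is computed but not used in this state update.)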

        batch_size = x.size(0)
        assert (
            x.dim() == 2
        ), f"Input should be 2D (batch_size, input_size), got {x.dim()}D"
        assert h_prev.size() == (
            batch_size,
            self.hidden_size,
        ), f"h_prev shape mismatch: {h_prev.size()}"
        assert c_prev.size() == (
            batch_size,
            self.hidden_size,
        ), f"c_prev shape mismatch: {c_prev.size()}"
        assert n_prev.size() == (
            batch_size,
            self.hidden_size,
        ), f"n_prev shape mismatch: {n_prev.size()}"

        x = self.dropout(x)
        h_prev = self.dropout(h_prev)

        q = self.Wq(x)
        k = self.Wk(x) / math.sqrt(self.hidden_size)
        v = self.Wv(x)

        if self.layer_norm:
            q = self.ln_q(q)
            k = self.ln_k(k)
            v = self.ln_v(v)

        i = self.sigmoid(self.ln_i(self.Wi(x)) if self.layer_norm else self.Wi(x))
        f = self.sigmoid(self.ln_f(self.Wf(x)) if self.layer_norm else self.Wf(x))
        o = self.sigmoid(self.ln_o(self.Wo(x)) if self.layer_norm else self.Wo(x))

        k_expanded = k.unsqueeze(-1)
        v_expanded = v.unsqueeze(-2)

        kv_interaction = k_expanded @ v_expanded

        kv_sum = kv_interaction.sum(dim=1)

        c = f * c_prev + i * kv_sum
        n = f * n_prev + i * k

        epsilon = 1e-8
        normalized_n = n / (torch.norm(n, dim=-1, keepdim=True) + epsilon)
        h = o * self.tanh(c * normalized_n)

        return h, c, n

    def init_hidden(self, batch_size, device=None):
        """
        Initialize hidden, cell, and normalization states.
        """
        if device is None:
            device = next(self.parameters()).device
        shape = (batch_size, self.hidden_size)
        return (
            torch.zeros(shape, device=device),
            torch.zeros(shape, device=device),
            torch.zeros(shape, device=device),
        )
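

# Minimal usage sketch (not part of the original module): shows how the cell can
# be stepped over a toy input sequence. The shapes and hyperparameters below are
# arbitrary illustrative assumptions.
if __name__ == "__main__":
    batch_size, seq_len, input_size, hidden_size = 4, 10, 16, 32
    cell = mLSTMCell(input_size, hidden_size)
    cell.eval()  # disable dropout for evaluation

    h, c, n = cell.init_hidden(batch_size)
    x_seq = torch.randn(batch_size, seq_len, input_size)

    # Process the sequence one time step at a time, carrying (h, c, n) forward.
    for t in range(seq_len):
        h, c, n = cell(x_seq[:, t, :], h, c, n)

    print(h.shape)  # expected: torch.Size([4, 32])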