Commit 3093b9f

[ENH] xLSTMTime implementation (#1709)
### Description

This PR implements xLSTMTime based on this [paper](https://arxiv.org/pdf/2407.10240); see also `sktime` issue [#6793](sktime/sktime#6793).
1 parent abc4562 commit 3093b9f
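For orientation, here is a minimal usage sketch of the new model. It assumes `xLSTMTime` follows the same `from_dataset` convention as the other pytorch-forecasting models; the constructor arguments below are not part of this diff and are illustrative only.

    # Hypothetical usage sketch -- assumes the standard pytorch-forecasting
    # `from_dataset` pattern; arguments other than the dataset are illustrative.
    from pytorch_forecasting.models.xlstm import xLSTMTime

    model = xLSTMTime.from_dataset(
        training_dataset,  # a TimeSeriesDataSet defined elsewhere
        learning_rate=1e-3,
    )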

File tree

17 files changed: +1238 −5 lines


docs/source/models.rst

Lines changed: 2 additions & 1 deletion
@@ -30,7 +30,8 @@ and you should take into account. Here is an overview over the pros and cons of
    :py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1
    :py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3
    :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4
-   :py:class:`~pytorch_forecasting.model.tide.TiDEModel`, "x", "x", "x", "", "", "", "", "x", "", 3
+   :py:class:`~pytorch_forecasting.models.tide.TiDEModel`, "x", "x", "x", "", "", "", "", "x", "", 3
+   :py:class:`~pytorch_forecasting.models.xlstm.xLSTMTime`, "x", "x", "x", "", "", "", "", "x", "", 3

 .. [#deepvar] Accounting for correlations using a multivariate loss function which converts the network into a DeepVAR model.

pytorch_forecasting/layers/__init__.py

Lines changed: 24 additions & 2 deletions
@@ -2,7 +2,12 @@
 Architectural deep learning layers from `nn.Module`.
 """

-from pytorch_forecasting.layers._attention import AttentionLayer, FullAttention
+from pytorch_forecasting.layers._attention import (
+    AttentionLayer,
+    FullAttention,
+    TriangularCausalMask,
+)
+from pytorch_forecasting.layers._decomposition import SeriesDecomposition
 from pytorch_forecasting.layers._embeddings import (
     DataEmbedding_inverted,
     EnEmbedding,
@@ -15,15 +20,32 @@
 from pytorch_forecasting.layers._output._flatten_head import (
     FlattenHead,
 )
+from pytorch_forecasting.layers._recurrent._mlstm import (
+    mLSTMCell,
+    mLSTMLayer,
+    mLSTMNetwork,
+)
+from pytorch_forecasting.layers._recurrent._slstm import (
+    sLSTMCell,
+    sLSTMLayer,
+    sLSTMNetwork,
+)

 __all__ = [
     "FullAttention",
-    "TriangularCausalMask",
     "AttentionLayer",
+    "TriangularCausalMask",
     "DataEmbedding_inverted",
     "EnEmbedding",
     "PositionalEmbedding",
     "Encoder",
     "EncoderLayer",
     "FlattenHead",
+    "mLSTMCell",
+    "mLSTMLayer",
+    "mLSTMNetwork",
+    "sLSTMCell",
+    "sLSTMLayer",
+    "sLSTMNetwork",
+    "SeriesDecomposition",
 ]
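These re-exports make the new recurrent building blocks importable from the top-level `layers` namespace. A quick sketch of the single-step interface, using the `mLSTMCell` signature from the cell implementation added later in this commit (the sLSTM classes are assumed to mirror it):

    import torch
    from pytorch_forecasting.layers import mLSTMCell

    cell = mLSTMCell(input_size=8, hidden_size=16)
    h, c, n = cell.init_hidden(batch_size=4)  # zero-initialized states, each (4, 16)
    x = torch.randn(4, 8)                     # one timestep of input features
    h, c, n = cell(x, h, c, n)                # single recurrent step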

pytorch_forecasting/layers/_attention/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -3,6 +3,9 @@
 """

 from pytorch_forecasting.layers._attention._attention_layer import AttentionLayer
-from pytorch_forecasting.layers._attention._full_attention import FullAttention
+from pytorch_forecasting.layers._attention._full_attention import (
+    FullAttention,
+    TriangularCausalMask,
+)

-__all__ = ["AttentionLayer", "FullAttention"]
+__all__ = ["AttentionLayer", "FullAttention", "TriangularCausalMask"]
pytorch_forecasting/layers/_recurrent/__init__.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+"""Recurrent Layers for Pytorch-Forecasting"""
+
+from pytorch_forecasting.layers._recurrent._mlstm import (
+    mLSTMCell,
+    mLSTMLayer,
+    mLSTMNetwork,
+)
+from pytorch_forecasting.layers._recurrent._slstm import (
+    sLSTMCell,
+    sLSTMLayer,
+    sLSTMNetwork,
+)
+
+__all__ = [
+    "mLSTMCell",
+    "mLSTMLayer",
+    "mLSTMNetwork",
+    "sLSTMCell",
+    "sLSTMLayer",
+    "sLSTMNetwork",
+]
pytorch_forecasting/layers/_recurrent/_mlstm/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+"""mLSTM layer"""
+
+from pytorch_forecasting.layers._recurrent._mlstm.cell import mLSTMCell
+from pytorch_forecasting.layers._recurrent._mlstm.layer import mLSTMLayer
+from pytorch_forecasting.layers._recurrent._mlstm.network import mLSTMNetwork
+
+__all__ = ["mLSTMCell", "mLSTMLayer", "mLSTMNetwork"]
pytorch_forecasting/layers/_recurrent/_mlstm/cell.py

Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
+import math
+
+import torch
+import torch.nn as nn
+
+
+class mLSTMCell(nn.Module):
+    """Implements the Matrix Long Short-Term Memory (mLSTM) Cell.
+
+    Implements the mLSTM algorithm as described in the paper:
+    (https://arxiv.org/pdf/2407.10240).
+
+    Parameters
+    ----------
+    input_size : int
+        Size of the input feature vector.
+    hidden_size : int
+        Number of hidden units in the LSTM cell.
+    dropout : float, optional
+        Dropout rate applied to inputs and hidden states, by default 0.2.
+    layer_norm : bool, optional
+        If True, apply Layer Normalization to gates and interactions, by default True.
+
+    Attributes
+    ----------
+    Wq : nn.Linear
+        Linear layer for computing the query vector.
+    Wk : nn.Linear
+        Linear layer for computing the key vector.
+    Wv : nn.Linear
+        Linear layer for computing the value vector.
+    Wi : nn.Linear
+        Linear layer for the input gate.
+    Wf : nn.Linear
+        Linear layer for the forget gate.
+    Wo : nn.Linear
+        Linear layer for the output gate.
+    dropout : nn.Dropout
+        Dropout regularization layer.
+    ln_q, ln_k, ln_v, ln_i, ln_f, ln_o : nn.LayerNorm
+        Optional layer normalization layers for respective computations.
+    """
+
+    def __init__(self, input_size, hidden_size, dropout=0.2, layer_norm=True):
+        super().__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.layer_norm = layer_norm
+
+        self.Wq = nn.Linear(input_size, hidden_size)
+        self.Wk = nn.Linear(input_size, hidden_size)
+        self.Wv = nn.Linear(input_size, hidden_size)
+
+        self.Wi = nn.Linear(input_size, hidden_size)
+        self.Wf = nn.Linear(input_size, hidden_size)
+        self.Wo = nn.Linear(input_size, hidden_size)
+
+        self.dropout = nn.Dropout(dropout)
+
+        if layer_norm:
+            self.ln_q = nn.LayerNorm(hidden_size)
+            self.ln_k = nn.LayerNorm(hidden_size)
+            self.ln_v = nn.LayerNorm(hidden_size)
+            self.ln_i = nn.LayerNorm(hidden_size)
+            self.ln_f = nn.LayerNorm(hidden_size)
+            self.ln_o = nn.LayerNorm(hidden_size)
+
+        self.sigmoid = nn.Sigmoid()
+        self.tanh = nn.Tanh()
+
+    def forward(self, x, h_prev, c_prev, n_prev):
+        """Compute the next hidden, cell, and normalized states in the mLSTM cell.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (batch_size, input_size).
+        h_prev : torch.Tensor
+            Previous hidden state.
+        c_prev : torch.Tensor
+            Previous cell state.
+        n_prev : torch.Tensor
+            Previous normalized state.
+
+        Returns
+        -------
+        tuple of torch.Tensor
+            h : torch.Tensor
+                Current hidden state.
+            c : torch.Tensor
+                Current cell state.
+            n : torch.Tensor
+                Current normalized state.
+        """
+
+        batch_size = x.size(0)
+        assert (
+            x.dim() == 2
+        ), f"Input should be 2D (batch_size, input_size), got {x.dim()}D"
+        assert h_prev.size() == (
+            batch_size,
+            self.hidden_size,
+        ), f"h_prev shape mismatch: {h_prev.size()}"
+        assert c_prev.size() == (
+            batch_size,
+            self.hidden_size,
+        ), f"c_prev shape mismatch: {c_prev.size()}"
+        assert n_prev.size() == (
+            batch_size,
+            self.hidden_size,
+        ), f"n_prev shape mismatch: {n_prev.size()}"
+
+        x = self.dropout(x)
+        h_prev = self.dropout(h_prev)
+
+        q = self.Wq(x)
+        k = self.Wk(x) / math.sqrt(self.hidden_size)
+        v = self.Wv(x)
+
+        if self.layer_norm:
+            q = self.ln_q(q)
+            k = self.ln_k(k)
+            v = self.ln_v(v)
+
+        i = self.sigmoid(self.ln_i(self.Wi(x)) if self.layer_norm else self.Wi(x))
+        f = self.sigmoid(self.ln_f(self.Wf(x)) if self.layer_norm else self.Wf(x))
+        o = self.sigmoid(self.ln_o(self.Wo(x)) if self.layer_norm else self.Wo(x))
+
+        k_expanded = k.unsqueeze(-1)
+        v_expanded = v.unsqueeze(-2)
+
+        kv_interaction = k_expanded @ v_expanded
+
+        kv_sum = kv_interaction.sum(dim=1)
+
+        c = f * c_prev + i * kv_sum
+        n = f * n_prev + i * k
+
+        epsilon = 1e-8
+        normalized_n = n / (torch.norm(n, dim=-1, keepdim=True) + epsilon)
+        h = o * self.tanh(c * normalized_n)
+
+        return h, c, n
+
+    def init_hidden(self, batch_size, device=None):
+        """
+        Initialize hidden, cell, and normalization states.
+        """
+        if device is None:
+            device = next(self.parameters()).device
+        shape = (batch_size, self.hidden_size)
+        return (
+            torch.zeros(shape, device=device),
+            torch.zeros(shape, device=device),
+            torch.zeros(shape, device=device),
+        )
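The cell exposes a single-step interface: each call forms the outer product of the key and value vectors, reduces it to a vector (`kv_sum`), and accumulates that into the cell state under the input and forget gates, while the normalizer state `n` keeps the hidden state bounded. A sequence is therefore processed by unrolling the cell manually; a minimal sketch, on the assumption that the `mLSTMLayer`/`mLSTMNetwork` wrappers added elsewhere in this commit (not shown in this excerpt) do essentially this:

    import torch

    cell = mLSTMCell(input_size=8, hidden_size=16)
    xs = torch.randn(4, 10, 8)                # (batch, seq_len, input_size)
    h, c, n = cell.init_hidden(batch_size=4)
    outputs = []
    for t in range(xs.size(1)):
        h, c, n = cell(xs[:, t, :], h, c, n)  # one mLSTM step per timestep
        outputs.append(h)
    ys = torch.stack(outputs, dim=1)          # (batch, seq_len, hidden_size)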
