diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py
index 29aeb24f5..033a2e297 100644
--- a/pytorch_forecasting/models/__init__.py
+++ b/pytorch_forecasting/models/__init__.py
@@ -10,6 +10,7 @@
 )
 from pytorch_forecasting.models.baseline import Baseline
 from pytorch_forecasting.models.deepar import DeepAR
+from pytorch_forecasting.models.informer import Informer
 from pytorch_forecasting.models.mlp import DecoderMLP
 from pytorch_forecasting.models.nbeats import NBeats
 from pytorch_forecasting.models.nhits import NHiTS
@@ -37,4 +38,5 @@
     "MultiEmbedding",
     "DecoderMLP",
     "TiDEModel",
+    "Informer",
 ]
diff --git a/pytorch_forecasting/models/informer/__init__.py b/pytorch_forecasting/models/informer/__init__.py
new file mode 100644
index 000000000..655803f22
--- /dev/null
+++ b/pytorch_forecasting/models/informer/__init__.py
@@ -0,0 +1,7 @@
+"""
+Informer Transformer for Long Sequence Time-Series Forecasting.
+"""
+
+from pytorch_forecasting.models.informer._informer import Informer
+
+__all__ = ["Informer"]
diff --git a/pytorch_forecasting/models/informer/_informer.py b/pytorch_forecasting/models/informer/_informer.py
new file mode 100644
index 000000000..5bb31fc23
--- /dev/null
+++ b/pytorch_forecasting/models/informer/_informer.py
@@ -0,0 +1,187 @@
+"""
+Informer Transformer for Long Sequence Time-Series Forecasting.
+"""
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+from pytorch_forecasting.data import TimeSeriesDataSet
+from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric
+from pytorch_forecasting.models.base import BaseModel
+from pytorch_forecasting.models.informer.sub_modules import (
+    AttentionLayer,
+    ConvLayer,
+    DataEmbedding,
+    Decoder,
+    DecoderLayer,
+    Encoder,
+    EncoderLayer,
+    ProbAttention,
+)
+
+
+class Informer(BaseModel):
+    """
+    Informer model for long sequence time-series forecasting.
+
+    Implements the ProbSparse-attention encoder-decoder architecture from
+    Zhou et al. (2021), "Informer: Beyond Efficient Transformer for Long
+    Sequence Time-Series Forecasting" (AAAI 2021).
+    """
+
+    def __init__(
+        self,
+        encoder_input: int = 5,
+        decoder_input: int = 10,
+        out_channels: int = 3,
+        seq_len: int = 20,
+        label_len: int = 4,
+        out_len: int = 10,
+        task_name: str = "long_term_forecast",
+        factor: int = 5,
+        d_model: int = 512,
+        n_heads: int = 8,
+        encoder_layers: int = 3,
+        decoder_layers: int = 2,
+        d_ff: int = 512,
+        dropout: float = 0.0,
+        embed: str = "fixed",
+        freq: str = "h",
+        activation: str = "gelu",
+        output_attention: bool = False,
+        loss: Optional[MultiHorizonMetric] = None,
+        distil: bool = True,
+        mix: bool = True,
+        logging_metrics: Optional[nn.ModuleList] = None,
+        **kwargs,
+    ):
+        if logging_metrics is None:
+            logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
+        if loss is None:
+            loss = MAE()
+        self.save_hyperparameters()
+        super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)
+        # hyperparameters are stored by ``save_hyperparameters`` and must be
+        # accessed via ``self.hparams``
+        self.enc_embedding = DataEmbedding(
+            self.hparams.encoder_input,
+            self.hparams.d_model,
+            self.hparams.embed,
+            self.hparams.freq,
+            self.hparams.dropout,
+        )
+        self.dec_embedding = DataEmbedding(
+            self.hparams.decoder_input,
+            self.hparams.d_model,
+            self.hparams.embed,
+            self.hparams.freq,
+            self.hparams.dropout,
+        )
+        self.encoder = Encoder(
+            [
+                EncoderLayer(
+                    AttentionLayer(
+                        ProbAttention(
+                            False,
+                            self.hparams.factor,
+                            attention_dropout=self.hparams.dropout,
+                            output_attention=False,
+                        ),
+                        self.hparams.d_model,
+                        self.hparams.n_heads,
+                    ),
+                    self.hparams.d_model,
+                    self.hparams.d_ff,
+                    dropout=self.hparams.dropout,
+                    activation=self.hparams.activation,
+                )
+                for _ in range(self.hparams.encoder_layers)
+            ],
+            (
+                [
+                    ConvLayer(self.hparams.d_model)
+                    for _ in range(self.hparams.encoder_layers - 1)
+                ]
+                if self.hparams.distil and ("forecast" in self.hparams.task_name)
+                else None
+            ),
+            norm_layer=torch.nn.LayerNorm(self.hparams.d_model),
+        )
+        self.decoder = Decoder(
+            [
+                DecoderLayer(
+                    AttentionLayer(
+                        ProbAttention(
+                            True,
+                            self.hparams.factor,
+                            attention_dropout=self.hparams.dropout,
+                            output_attention=False,
+                        ),
+                        self.hparams.d_model,
+                        self.hparams.n_heads,
+                    ),
+                    AttentionLayer(
+                        ProbAttention(
+                            False,
+                            self.hparams.factor,
+                            attention_dropout=self.hparams.dropout,
+                            output_attention=False,
+                        ),
+                        self.hparams.d_model,
+                        self.hparams.n_heads,
+                    ),
+                    self.hparams.d_model,
+                    self.hparams.d_ff,
+                    dropout=self.hparams.dropout,
+                    activation=self.hparams.activation,
+                )
+                for _ in range(self.hparams.decoder_layers)
+            ],
+            norm_layer=torch.nn.LayerNorm(self.hparams.d_model),
+            projection=nn.Linear(
+                self.hparams.d_model, self.hparams.out_channels, bias=True
+            ),
+        )
+
+    @classmethod
+    def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
+        """
+        Convenience function to create network from :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.
+
+        Args:
+            dataset (TimeSeriesDataSet): dataset where sole predictor is the target.
+            **kwargs: additional arguments to be passed to ``__init__`` method.
+
+        Returns:
+            Informer
+        """  # noqa: E501
+        # sequence lengths come from the dataset; channel sizes
+        # (``encoder_input``/``decoder_input``) are architecture choices
+        # and can be overridden via ``kwargs``
+        new_kwargs = {
+            "seq_len": dataset.max_encoder_length,
+            "out_len": dataset.max_prediction_length,
+        }
+        new_kwargs.update(kwargs)
+
+        # create class and return
+        return super().from_dataset(
+            dataset,
+            **new_kwargs,
+        )
+
+    def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
+        enc_out = self.enc_embedding(x_enc, x_mark_enc)
+        dec_out = self.dec_embedding(x_dec, x_mark_dec)
+        enc_out, attns = self.encoder(enc_out, attn_mask=None)
+
+        dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)
+
+        return dec_out  # [B, L, D]
+
+    def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
+        # Normalization
+        mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E
+        x_enc = x_enc - mean_enc
+        std_enc = torch.sqrt(
+            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5
+        ).detach()  # B x 1 x E
+        x_enc = x_enc / std_enc
+
+        enc_out = self.enc_embedding(x_enc, x_mark_enc)
+        dec_out = self.dec_embedding(x_dec, x_mark_dec)
+        enc_out, attns = self.encoder(enc_out, attn_mask=None)
+
+        dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)
+
+        dec_out = dec_out * std_enc + mean_enc
+        return dec_out  # [B, L, D]
+
+    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
+        if self.hparams.task_name == "long_term_forecast":
+            dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
+            return dec_out[:, -self.hparams.out_len :, :]  # [B, L, D]
+        if self.hparams.task_name == "short_term_forecast":
+            dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
+            return dec_out[:, -self.hparams.out_len :, :]  # [B, L, D]
+        raise ValueError(f"unknown task_name: {self.hparams.task_name}")
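Reviewer note (not part of the diff): a minimal usage sketch for the new model. It assumes an existing `TimeSeriesDataSet` named `training` whose sole predictor is the target, as `Informer.from_dataset` requires; the hyperparameter values are illustrative only and mirror the integration test at the end of this diff.

```python
import lightning.pytorch as pl

from pytorch_forecasting.models import Informer

# `training` is an assumed, pre-built TimeSeriesDataSet; batch_size is illustrative
train_dataloader = training.to_dataloader(train=True, batch_size=64)

# sequence lengths (seq_len, out_len) are inferred from the dataset
net = Informer.from_dataset(
    training,
    learning_rate=0.15,
    factor=5,
    n_heads=8,
)

trainer = pl.Trainer(max_epochs=2, gradient_clip_val=0.1)
trainer.fit(net, train_dataloaders=train_dataloader)
```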
diff --git a/pytorch_forecasting/models/informer/sub_modules.py b/pytorch_forecasting/models/informer/sub_modules.py
new file mode 100644
index 000000000..c162aedb0
--- /dev/null
+++ b/pytorch_forecasting/models/informer/sub_modules.py
@@ -0,0 +1,430 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Decoder(nn.Module):
+    def __init__(self, layers, norm_layer=None, projection=None):
+        super(Decoder, self).__init__()
+        self.layers = nn.ModuleList(layers)
+        self.norm = norm_layer
+        self.projection = projection
+
+    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
+        for layer in self.layers:
+            x = layer(
+                x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta
+            )
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        if self.projection is not None:
+            x = self.projection(x)
+        return x
+
+
+class DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        self_attention,
+        cross_attention,
+        d_model,
+        d_ff=None,
+        dropout=0.1,
+        activation="relu",
+    ):
+        super(DecoderLayer, self).__init__()
+        d_ff = d_ff or 4 * d_model
+        self.self_attention = self_attention
+        self.cross_attention = cross_attention
+        # position-wise feed-forward, implemented as two 1x1 convolutions
+        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.activation = F.relu if activation == "relu" else F.gelu
+
+    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
+        x = x + self.dropout(
+            self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
+        )
+        x = self.norm1(x)
+
+        x = x + self.dropout(
+            self.cross_attention(
+                x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
+            )[0]
+        )
+
+        y = x = self.norm2(x)
+        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+        y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+        return self.norm3(x + y)
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
+        super(EncoderLayer, self).__init__()
+        d_ff = d_ff or 4 * d_model
+        self.attention = attention
+        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.activation = F.relu if activation == "relu" else F.gelu
+
+    def forward(self, x, attn_mask=None, tau=None, delta=None):
+        new_x, attn = self.attention(
+            x, x, x, attn_mask=attn_mask, tau=tau, delta=delta
+        )
+        x = x + self.dropout(new_x)
+
+        y = x = self.norm1(x)
+        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+        y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+        return self.norm2(x + y), attn
+
+
+class Encoder(nn.Module):
+    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
+        super(Encoder, self).__init__()
+        self.attn_layers = nn.ModuleList(attn_layers)
+        self.conv_layers = (
+            nn.ModuleList(conv_layers) if conv_layers is not None else None
+        )
+        self.norm = norm_layer
+
+    def forward(self, x, attn_mask=None, tau=None, delta=None):
+        # x [B, L, D]
+        attns = []
+        if self.conv_layers is not None:
+            # distilling: each attention layer is followed by a conv layer
+            # that halves the sequence length; the last attention layer has
+            # no conv partner
+            for i, (attn_layer, conv_layer) in enumerate(
+                zip(self.attn_layers, self.conv_layers)
+            ):
+                delta = delta if i == 0 else None
+                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
+                x = conv_layer(x)
+                attns.append(attn)
+            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
+            attns.append(attn)
+        else:
+            for attn_layer in self.attn_layers:
+                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
+                attns.append(attn)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x, attns
+
+
+class ConvLayer(nn.Module):
+    def __init__(self, c_in):
+        super(ConvLayer, self).__init__()
+        # padding of 1 keeps the sequence length for kernel_size=3 on
+        # torch >= 1.5 (mirrors the version guard in TokenEmbedding below)
+        padding = 1 if torch.__version__ >= "1.5.0" else 2
+        self.downConv = nn.Conv1d(
+            in_channels=c_in,
+            out_channels=c_in,
+            kernel_size=3,
+            padding=padding,
+            padding_mode="circular",
+        )
+        self.norm = nn.BatchNorm1d(c_in)
+        self.activation = nn.ELU()
+        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+
+    def forward(self, x):
+        x = self.downConv(x.permute(0, 2, 1))
+        x = self.norm(x)
+        x = self.activation(x)
+        x = self.maxPool(x)
+        x = x.transpose(1, 2)
+        return x
+
+
+class ProbAttention(nn.Module):
+    def __init__(
+        self,
+        mask_flag=True,
+        factor=5,
+        scale=None,
+        attention_dropout=0.1,
+        output_attention=False,
+    ):
+        super(ProbAttention, self).__init__()
+        self.factor = factor
+        self.scale = scale
+        self.mask_flag = mask_flag
+        self.output_attention = output_attention
+        self.dropout = nn.Dropout(attention_dropout)
+
+    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
+        # Q [B, H, L, D]
+        B, H, L_K, E = K.shape
+        _, _, L_Q, _ = Q.shape
+
+        # calculate the sampled Q_K
+        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
+        # real U = U_part(factor*ln(L_k))*L_q
+        index_sample = torch.randint(L_K, (L_Q, sample_k))
+        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
+        # squeeze only the singleton matmul dimension so that a batch size
+        # of 1 is preserved
+        Q_K_sample = torch.matmul(
+            Q.unsqueeze(-2), K_sample.transpose(-2, -1)
+        ).squeeze(-2)
+
+        # find the Top_k query with sparsity measurement
+        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
+        M_top = M.topk(n_top, sorted=False)[1]
+
+        # use the reduced Q to calculate Q_K
+        Q_reduce = Q[
+            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :
+        ]  # factor*ln(L_q)
+        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k
+
+        return Q_K, M_top
+
+    def _get_initial_context(self, V, L_Q):
+        B, H, L_V, D = V.shape
+        if not self.mask_flag:
+            # V_sum = V.sum(dim=-2)
+            V_sum = V.mean(dim=-2)
+            contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
+        else:  # use mask
+            # requires that L_Q == L_V, i.e. for self-attention only
+            assert L_Q == L_V
+            contex = V.cumsum(dim=-2)
+        return contex
+
+    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
+        B, H, L_V, D = V.shape
+
+        if self.mask_flag:
+            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
+            scores.masked_fill_(attn_mask.mask, -np.inf)
+
+        attn = torch.softmax(scores, dim=-1)  # nn.Softmax(dim=-1)(scores)
+
+        context_in[
+            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
+        ] = torch.matmul(attn, V).type_as(context_in)
+        if self.output_attention:
+            attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
+            attns[
+                torch.arange(B)[:, None, None],
+                torch.arange(H)[None, :, None],
+                index,
+                :,
+            ] = attn
+            return context_in, attns
+        else:
+            return context_in, None
+
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+        B, L_Q, H, D = queries.shape
+        _, L_K, _, _ = keys.shape
+
+        queries = queries.transpose(2, 1)
+        keys = keys.transpose(2, 1)
+        values = values.transpose(2, 1)
+
+        U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item()  # c*ln(L_k)
+        u = self.factor * np.ceil(np.log(L_Q)).astype("int").item()  # c*ln(L_q)
+
+        U_part = U_part if U_part < L_K else L_K
+        u = u if u < L_Q else L_Q
+
+        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)
+
+        # add scale factor (``scale`` is always truthy after the ``or``)
+        scale = self.scale or 1.0 / math.sqrt(D)
+        scores_top = scores_top * scale
+        # get the context
+        context = self._get_initial_context(values, L_Q)
+        # update the context with selected top_k queries
+        context, attn = self._update_context(
+            context, values, scores_top, index, L_Q, attn_mask
+        )
+
+        # transpose back to [B, L, H, D] so AttentionLayer can flatten the
+        # heads per position with ``out.view(B, L, -1)``
+        return context.transpose(2, 1).contiguous(), attn
+
+
+class ProbMask:
+    def __init__(self, B, H, L, index, scores, device="cpu"):
+        _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
+        _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
+        indicator = _mask_ex[
+            torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :
+        ].to(device)
+        self._mask = indicator.view(scores.shape).to(device)
+
+    @property
+    def mask(self):
+        return self._mask
+
+
+class AttentionLayer(nn.Module):
+    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
+        super(AttentionLayer, self).__init__()
+
+        d_keys = d_keys or (d_model // n_heads)
+        d_values = d_values or (d_model // n_heads)
+
+        self.inner_attention = attention
+        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
+        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
+        self.value_projection = nn.Linear(d_model, d_values * n_heads)
+        self.out_projection = nn.Linear(d_values * n_heads, d_model)
+        self.n_heads = n_heads
+
+    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+        B, L, _ = queries.shape
+        _, S, _ = keys.shape
+        H = self.n_heads
+
+        queries = self.query_projection(queries).view(B, L, H, -1)
+        keys = self.key_projection(keys).view(B, S, H, -1)
+        values = self.value_projection(values).view(B, S, H, -1)
+
+        out, attn = self.inner_attention(
+            queries, keys, values, attn_mask, tau=tau, delta=delta
+        )
+        out = out.view(B, L, -1)
+
+        return self.out_projection(out), attn
+
+
+class DataEmbedding(nn.Module):
+    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
+        super(DataEmbedding, self).__init__()
+
+        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
+        self.position_embedding = PositionalEmbedding(d_model=d_model)
+        self.temporal_embedding = (
+            TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
+            if embed_type != "timeF"
+            else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
+        )
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(self, x, x_mark):
+        if x_mark is None:
+            x = self.value_embedding(x) + self.position_embedding(x)
+        else:
+            x = (
+                self.value_embedding(x)
+                + self.temporal_embedding(x_mark)
+                + self.position_embedding(x)
+            )
+        return self.dropout(x)
+
+
+class TimeFeatureEmbedding(nn.Module):
+    def __init__(self, d_model, embed_type="timeF", freq="h"):
+        super(TimeFeatureEmbedding, self).__init__()
+
+        freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3}
+        d_inp = freq_map[freq]
+        self.embed = nn.Linear(d_inp, d_model, bias=False)
+
+    def forward(self, x):
+        return self.embed(x)
+
+
+class TokenEmbedding(nn.Module):
+    def __init__(self, c_in, d_model):
+        super(TokenEmbedding, self).__init__()
+        padding = 1 if torch.__version__ >= "1.5.0" else 2
+        self.tokenConv = nn.Conv1d(
+            in_channels=c_in,
+            out_channels=d_model,
+            kernel_size=3,
+            padding=padding,
+            padding_mode="circular",
+            bias=False,
+        )
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d):
+                nn.init.kaiming_normal_(
+                    m.weight, mode="fan_in", nonlinearity="leaky_relu"
+                )
+
+    def forward(self, x):
+        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
+        return x
+
+
+class PositionalEmbedding(nn.Module):
+    def __init__(self, d_model, max_len=5000):
+        super(PositionalEmbedding, self).__init__()
+        # Compute the positional encodings once in log space.
+        pe = torch.zeros(max_len, d_model).float()
+        pe.requires_grad = False
+
+        position = torch.arange(0, max_len).float().unsqueeze(1)
+        div_term = (
+            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
+        ).exp()
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+
+        pe = pe.unsqueeze(0)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x):
+        return self.pe[:, : x.size(1)]
+
+
+class TemporalEmbedding(nn.Module):
+    def __init__(self, d_model, embed_type="fixed", freq="h"):
+        super(TemporalEmbedding, self).__init__()
+
+        minute_size = 4
+        hour_size = 24
+        weekday_size = 7
+        day_size = 32
+        month_size = 13
+
+        Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding
+        if freq == "t":
+            self.minute_embed = Embed(minute_size, d_model)
+        self.hour_embed = Embed(hour_size, d_model)
+        self.weekday_embed = Embed(weekday_size, d_model)
+        self.day_embed = Embed(day_size, d_model)
+        self.month_embed = Embed(month_size, d_model)
+
+    def forward(self, x):
+        # expected column order: [month, day, weekday, hour, (minute)]
+        x = x.long()
+        minute_x = (
+            self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0
+        )
+        hour_x = self.hour_embed(x[:, :, 3])
+        weekday_x = self.weekday_embed(x[:, :, 2])
+        day_x = self.day_embed(x[:, :, 1])
+        month_x = self.month_embed(x[:, :, 0])
+
+        return hour_x + weekday_x + day_x + month_x + minute_x
+
+
+class FixedEmbedding(nn.Module):
+    def __init__(self, c_in, d_model):
+        super(FixedEmbedding, self).__init__()
+
+        w = torch.zeros(c_in, d_model).float()
+        w.requires_grad = False
+
+        position = torch.arange(0, c_in).float().unsqueeze(1)
+        div_term = (
+            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
+        ).exp()
+
+        w[:, 0::2] = torch.sin(position * div_term)
+        w[:, 1::2] = torch.cos(position * div_term)
+
+        self.emb = nn.Embedding(c_in, d_model)
+        self.emb.weight = nn.Parameter(w, requires_grad=False)
+
+    def forward(self, x):
+        return self.emb(x).detach()
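Reviewer note (not part of the diff): the core of the module above is ProbSparse attention. `_prob_QK` scores each query against only `factor * ceil(ln(L_K))` randomly sampled keys, keeps the `factor * ceil(ln(L_Q))` queries with the largest max-minus-mean score, and computes exact attention only for those, which is what brings the cost down from O(L^2) to O(L log L); all other positions fall back to the mean of the values (or a cumulative sum in the causal case). A small self-contained shape check under those assumptions:

```python
import torch

from pytorch_forecasting.models.informer.sub_modules import (
    AttentionLayer,
    ProbAttention,
)

B, L, d_model, n_heads = 2, 96, 64, 4
x = torch.randn(B, L, d_model)

# Self-attention with ProbSparse scoring: with factor=5 and L=96, only
# u = 5 * ceil(ln(96)) = 25 queries receive exact attention; the rest use
# the mean of the values (mask_flag=False, i.e. non-causal).
attn = AttentionLayer(
    ProbAttention(mask_flag=False, factor=5, output_attention=False),
    d_model,
    n_heads,
)
out, _ = attn(x, x, x, attn_mask=None)
assert out.shape == (B, L, d_model)
```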
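Reviewer note (not part of the diff): `DataEmbedding` sums three parts: a circular-padded 1D convolution over the input channels (`TokenEmbedding`), a fixed sinusoidal positional encoding, and embeddings of integer calendar features (`TemporalEmbedding`). For `freq="h"` the calendar tensor is expected to hold `[month, day, weekday, hour]` in that column order, as indexed in `TemporalEmbedding.forward`. A sketch of the expected shapes:

```python
import torch

from pytorch_forecasting.models.informer.sub_modules import DataEmbedding

B, L, c_in, d_model = 2, 24, 5, 64
x = torch.randn(B, L, c_in)

# calendar features for freq="h": [month, day, weekday, hour]
x_mark = torch.stack(
    [
        torch.randint(0, 13, (B, L)),  # month (FixedEmbedding has 13 slots)
        torch.randint(0, 32, (B, L)),  # day of month
        torch.randint(0, 7, (B, L)),   # weekday
        torch.randint(0, 24, (B, L)),  # hour
    ],
    dim=-1,
)

emb = DataEmbedding(c_in=c_in, d_model=d_model, embed_type="fixed", freq="h")
out = emb(x, x_mark)
assert out.shape == (B, L, d_model)
```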
diff --git a/tests/test_models/test_informer.py b/tests/test_models/test_informer.py
new file mode 100644
index 000000000..e2e665575
--- /dev/null
+++ b/tests/test_models/test_informer.py
@@ -0,0 +1,99 @@
+import pickle
+import shutil
+
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks import EarlyStopping
+from lightning.pytorch.loggers import TensorBoardLogger
+import pytest
+
+from pytorch_forecasting.models import Informer
+from pytorch_forecasting.utils._dependencies import _get_installed_packages
+
+
+def test_integration(dataloaders_fixed_window_without_covariates, tmp_path):
+    train_dataloader = dataloaders_fixed_window_without_covariates["train"]
+    val_dataloader = dataloaders_fixed_window_without_covariates["val"]
+    test_dataloader = dataloaders_fixed_window_without_covariates["test"]
+
+    early_stop_callback = EarlyStopping(
+        monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
+    )
+
+    logger = TensorBoardLogger(tmp_path)
+    trainer = pl.Trainer(
+        max_epochs=2,
+        gradient_clip_val=0.1,
+        callbacks=[early_stop_callback],
+        enable_checkpointing=True,
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        logger=logger,
+    )
+
+    net = Informer.from_dataset(
+        train_dataloader.dataset,
+        learning_rate=0.15,
+        factor=5,
+        n_heads=8,
+    )
+    net.size()
+    try:
+        trainer.fit(
+            net,
+            train_dataloaders=train_dataloader,
+            val_dataloaders=val_dataloader,
+        )
+        test_outputs = trainer.test(net, dataloaders=test_dataloader)
+        assert len(test_outputs) > 0
+        # check loading
+        net = Informer.load_from_checkpoint(
+            trainer.checkpoint_callback.best_model_path
+        )
+
+        # check prediction
+        net.predict(
+            val_dataloader,
+            fast_dev_run=True,
+            return_index=True,
+            return_decoder_lengths=True,
+        )
+    finally:
+        shutil.rmtree(tmp_path, ignore_errors=True)
+
+
+@pytest.fixture(scope="session")
+def model(dataloaders_fixed_window_without_covariates):
+    dataset = dataloaders_fixed_window_without_covariates["train"].dataset
+    net = Informer.from_dataset(
+        dataset,
+        learning_rate=0.15,
+        factor=5,
+        n_heads=8,
+    )
+    return net
+
+
+def test_pickle(model):
+    pkl = pickle.dumps(model)
+    pickle.loads(pkl)  # noqa: S301
+
+
+@pytest.mark.skipif(
+    "matplotlib" not in _get_installed_packages(),
+    reason="skip test if required package matplotlib not installed",
+)
+def test_interpretation(model, dataloaders_fixed_window_without_covariates):
+    raw_predictions = model.predict(
+        dataloaders_fixed_window_without_covariates["val"],
+        mode="raw",
+        return_x=True,
+        fast_dev_run=True,
+    )
+    model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=0)