Skip to content

Commit d76b62c

Browse files
committed
add shape test
1 parent 62f30cf commit d76b62c

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

pytorch_forecasting/tests/test_all_estimators_v2.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import lightning.pytorch as pl
77
from lightning.pytorch.callbacks import EarlyStopping
88
from lightning.pytorch.loggers import TensorBoardLogger
9+
import torch
910
import torch.nn as nn
1011

1112
from pytorch_forecasting.tests.test_all_estimators import (
@@ -77,6 +78,19 @@ def _integration(
7778
test_outputs = trainer.test(net, dataloaders=test_dataloader)
7879
assert len(test_outputs) > 0
7980

81+
# todo: add the predict pipeline and make this test cleaner
82+
x, y = next(iter(test_dataloader))
83+
net.eval()
84+
with torch.no_grad():
85+
output = net(x)
86+
net.train()
87+
prediction = output["prediction"]
88+
n_dims = len(prediction.shape)
89+
assert n_dims == 3, (
90+
f"Prediction output must be 3D, but got {n_dims}D tensor "
91+
f"with shape {output.shape}"
92+
)
93+
8094
shutil.rmtree(tmp_path, ignore_errors=True)
8195

8296

0 commit comments

Comments
 (0)