File tree Expand file tree Collapse file tree 1 file changed +14
-0
lines changed
pytorch_forecasting/tests Expand file tree Collapse file tree 1 file changed +14
-0
lines changed Original file line number Diff line number Diff line change 6
6
import lightning .pytorch as pl
7
7
from lightning .pytorch .callbacks import EarlyStopping
8
8
from lightning .pytorch .loggers import TensorBoardLogger
9
+ import torch
9
10
import torch .nn as nn
10
11
11
12
from pytorch_forecasting .tests .test_all_estimators import (
@@ -77,6 +78,19 @@ def _integration(
77
78
test_outputs = trainer .test (net , dataloaders = test_dataloader )
78
79
assert len (test_outputs ) > 0
79
80
81
+ # todo: add the predict pipeline and make this test cleaner
82
+ x , y = next (iter (test_dataloader ))
83
+ net .eval ()
84
+ with torch .no_grad ():
85
+ output = net (x )
86
+ net .train ()
87
+ prediction = output ["prediction" ]
88
+ n_dims = len (prediction .shape )
89
+ assert n_dims == 3 , (
90
+ f"Prediction output must be 3D, but got { n_dims } D tensor "
91
+ f"with shape { output .shape } "
92
+ )
93
+
80
94
shutil .rmtree (tmp_path , ignore_errors = True )
81
95
82
96
You can’t perform that action at this time.
0 commit comments