You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -440,12 +442,14 @@ class Ensemble(torch.nn.ModuleList):
440
442
441
443
Args:
442
444
modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over.
445
+
return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If set to True, the model returns 4 arguments (mean_y, mean_neg_dy, std_y, std_neg_dy) instead of 2 (mean_y, mean_neg_dy).
0 commit comments