Skip to content

Commit f21f3de

Browse files
committed
Update docstring
1 parent 365edea commit f21f3de

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

torchmdnet/models/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,12 @@ def forward(
461461
):
462462
"""Average predictions over all models in the ensemble.
463463
The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble.
464+
Args:
465+
*args: Positional arguments to forward to the models.
466+
**kwargs: Keyword arguments to forward to the models.
467+
Returns:
468+
Tuple[Tensor, Optional[Tensor]] or Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: The average and standard deviation of the predictions over all models in the ensemble. If return_std is False, the output is a tuple (mean_y, mean_neg_dy). If return_std is True, the output is a tuple (mean_y, mean_neg_dy, std_y, std_neg_dy).
469+
464470
"""
465471
y = []
466472
neg_dy = []

0 commit comments

Comments
 (0)