Skip to content

Commit 5a60b40

Browse files
committed
Make Ensemble variadic
1 parent 9b0a362 commit 5a60b40

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

torchmdnet/models/model.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -453,23 +453,18 @@ def __init__(self, modules: List[nn.Module], return_std: bool = False):
453453

454454
def forward(
455455
self,
456-
z: Tensor,
457-
pos: Tensor,
458-
batch: Optional[Tensor] = None,
459-
box: Optional[Tensor] = None,
460-
q: Optional[Tensor] = None,
461-
s: Optional[Tensor] = None,
462-
extra_args: Optional[Dict[str, Tensor]] = None,
456+
*args,
457+
**kwargs,
463458
):
459+
"""Average predictions over all models in the ensemble.
460+
The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble.
461+
"""
464462
y = []
465463
neg_dy = []
466464
for model in self:
467-
res = model(
468-
z=z, pos=pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args
469-
)
465+
res = model(*args, **kwargs)
470466
y.append(res[0])
471467
neg_dy.append(res[1])
472-
473468
y = torch.stack(y)
474469
neg_dy = torch.stack(neg_dy)
475470
y_mean = torch.mean(y, axis=0)

0 commit comments

Comments
 (0)