Skip to content

Commit 5ba3db9

Browse files
committed
returning the standard deviation is now optional
1 parent 94f94d9 commit 5ba3db9

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

torchmdnet/models/model.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,22 +139,24 @@ def create_model(args, prior_model=None, mean=None, std=None):
139139
return model
140140

141141

142-
def load_model(filepath, args=None, device="cpu", **kwargs):
142+
def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs):
143143
"""Load a model from a checkpoint file.
144144
145145
If a list of paths is given, an :py:mod:`Ensemble` model is returned.
146146
Args:
147147
filepath (str or list): Path to the checkpoint file or a list of paths.
148148
args (dict, optional): Arguments for the model. Defaults to None.
149149
device (str, optional): Device on which the model should be loaded. Defaults to "cpu".
150+
return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False.
150151
**kwargs: Extra keyword arguments for the model.
151152
152153
Returns:
153154
nn.Module: An instance of the TorchMD_Net model.
154155
"""
155156
if isinstance(filepath, (list, tuple)):
156157
return Ensemble(
157-
[load_model(f, args=args, device=device, **kwargs) for f in filepath]
158+
[load_model(f, args=args, device=device, **kwargs) for f in filepath],
159+
return_std=return_std,
158160
)
159161

160162
ckpt = torch.load(filepath, map_location="cpu")
@@ -440,12 +442,14 @@ class Ensemble(torch.nn.ModuleList):
440442
441443
Args:
442444
modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over.
445+
return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If set to True, the model returns 4 arguments (mean_y, mean_neg_dy, std_y, std_neg_dy) instead of 2 (mean_y, mean_neg_dy).
443446
"""
444447

445-
def __init__(self, modules: List[nn.Module]):
448+
def __init__(self, modules: List[nn.Module], return_std: bool = False):
446449
for module in modules:
447450
assert isinstance(module, TorchMD_Net)
448451
super().__init__(modules)
452+
self.return_std = return_std
449453

450454
def forward(
451455
self,
@@ -472,4 +476,8 @@ def forward(
472476
neg_dy_mean = torch.mean(neg_dy, axis=0)
473477
y_std = torch.std(y, axis=0)
474478
neg_dy_std = torch.std(neg_dy, axis=0)
475-
return y_mean, neg_dy_mean, y_std, neg_dy_std
479+
480+
if self.return_std:
481+
return y_mean, neg_dy_mean, y_std, neg_dy_std
482+
else:
483+
return y_mean, neg_dy_mean

0 commit comments

Comments
 (0)