diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index a2a80f901..015d629f7 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -378,6 +378,10 @@ def forward( assert z.dim() == 1 and z.dtype == torch.long batch = torch.zeros_like(z) if batch is None else batch + # trick to incorporate SPICE pqs + # set charge: true in yaml ((?) currently I do it) + q = extra_args["pq"] + if self.derivative: pos.requires_grad_(True) # run the potentially wrapped representation model diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index a2006c9e7..2b830ce19 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -239,7 +239,8 @@ def forward( # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q if q is None: q = torch.zeros_like(z, device=z.device, dtype=z.dtype) - else: + # if not atom-wise, make atom-wise (pq is already atom-wise) + if z.shape != q.shape: q = q[batch] zp = z if self.static_shapes: