diff --git a/src/single_struct_calculator.py b/src/single_struct_calculator.py index 8a29d88..473931c 100644 --- a/src/single_struct_calculator.py +++ b/src/single_struct_calculator.py @@ -43,7 +43,7 @@ def __init__( model = PETMLIPWrapper( model, MLIP_SETTINGS.USE_ENERGIES, MLIP_SETTINGS.USE_FORCES ) - if torch.cuda.is_available() and (torch.cuda.device_count() > 1): + if torch.cuda.is_available() and (torch.cuda.device_count() > 4): model = DataParallel(model) model = model.to(torch.device("cuda:0")) @@ -85,7 +85,7 @@ def forward(self, structure): graph = graph.to(self.device) if self.quadrature_order is None: - if torch.cuda.is_available() and (torch.cuda.device_count() > 1): + if torch.cuda.is_available() and (torch.cuda.device_count() > 4): self.model.module.augmentation = self.use_augmentation self.model.module.create_graph = False prediction_energy, prediction_forces = self.model([graph])