diff --git a/src/metatrain/experimental/flashmd/trainer.py b/src/metatrain/experimental/flashmd/trainer.py
index 990e94ef6..eafcb0c39 100644
--- a/src/metatrain/experimental/flashmd/trainer.py
+++ b/src/metatrain/experimental/flashmd/trainer.py
@@ -472,48 +472,53 @@ def train(
                     )
                 )
 
-            val_loss = 0.0
-            for batch in val_dataloader:
-                # Skip None batches (those outside batch_atom_bounds)
-                if should_skip_batch(batch, is_distributed, device):
-                    continue
-
-                systems, targets, extra_data = unpack_batch(batch)
-                systems, targets, extra_data = batch_to(
-                    systems, targets, extra_data, dtype=dtype, device=device
-                )
-                predictions = evaluate_model(
-                    model,
-                    systems,
-                    {key: train_targets[key] for key in targets.keys()},
-                    is_training=False,
-                )
+            with torch.set_grad_enabled(
+                any(target_info.gradients for target_info in train_targets.values())
+            ):  # keep gradients on if any of the targets require them
+                val_loss = 0.0
+                for batch in val_dataloader:
+                    # Skip None batches (those outside batch_atom_bounds)
+                    if should_skip_batch(batch, is_distributed, device):
+                        continue
+
+                    systems, targets, extra_data = unpack_batch(batch)
+                    systems, targets, extra_data = batch_to(
+                        systems, targets, extra_data, dtype=dtype, device=device
+                    )
+                    predictions = evaluate_model(
+                        model,
+                        systems,
+                        {key: train_targets[key] for key in targets.keys()},
+                        is_training=False,
+                    )
 
-                # average by the number of atoms
-                predictions = average_by_num_atoms(
-                    predictions, systems, per_structure_targets
-                )
-                targets = average_by_num_atoms(targets, systems, per_structure_targets)
-                val_loss_batch = loss_fn(predictions, targets, extra_data)
+                    # average by the number of atoms
+                    predictions = average_by_num_atoms(
+                        predictions, systems, per_structure_targets
+                    )
+                    targets = average_by_num_atoms(
+                        targets, systems, per_structure_targets
+                    )
+                    val_loss_batch = loss_fn(predictions, targets, extra_data)
 
-                if is_distributed:
-                    # sum the loss over all processes
-                    torch.distributed.all_reduce(val_loss_batch)
-                val_loss += val_loss_batch.item()
+                    if is_distributed:
+                        # sum the loss over all processes
+                        torch.distributed.all_reduce(val_loss_batch)
+                    val_loss += val_loss_batch.item()
 
-                scaled_predictions = (model.module if is_distributed else model).scaler(
-                    systems, predictions
-                )
-                scaled_targets = (model.module if is_distributed else model).scaler(
-                    systems, targets
-                )
-                val_rmse_calculator.update(
-                    scaled_predictions, scaled_targets, extra_data
-                )
-                if self.hypers["log_mae"]:
-                    val_mae_calculator.update(
+                    scaled_predictions = (
+                        model.module if is_distributed else model
+                    ).scaler(systems, predictions)
+                    scaled_targets = (model.module if is_distributed else model).scaler(
+                        systems, targets
+                    )
+                    val_rmse_calculator.update(
                         scaled_predictions, scaled_targets, extra_data
                     )
+                    if self.hypers["log_mae"]:
+                        val_mae_calculator.update(
+                            scaled_predictions, scaled_targets, extra_data
+                        )
 
             finalized_val_info = val_rmse_calculator.finalize(
                 not_per_atom=["positions_gradients"] + per_structure_targets,
diff --git a/src/metatrain/experimental/mace/trainer.py b/src/metatrain/experimental/mace/trainer.py
index 62f723d35..56777b1a7 100644
--- a/src/metatrain/experimental/mace/trainer.py
+++ b/src/metatrain/experimental/mace/trainer.py
@@ -476,48 +476,53 @@ def train(
                     )
                 )
 
-            val_loss = 0.0
-            for batch in val_dataloader:
-                # Skip None batches (those outside batch_atom_bounds)
-                if should_skip_batch(batch, is_distributed, device):
-                    continue
-
-                systems, targets, extra_data = unpack_batch(batch)
-                systems, targets, extra_data = batch_to(
-                    systems, targets, extra_data, dtype=dtype, device=device
-                )
-                predictions = evaluate_model(
-                    model,
-                    systems,
-                    {key: train_targets[key] for key in targets.keys()},
-                    is_training=False,
-                )
+            with torch.set_grad_enabled(
+                any(target_info.gradients for target_info in train_targets.values())
+            ):  # keep gradients on if any of the targets require them
+                val_loss = 0.0
+                for batch in val_dataloader:
+                    # Skip None batches (those outside batch_atom_bounds)
+                    if should_skip_batch(batch, is_distributed, device):
+                        continue
+
+                    systems, targets, extra_data = unpack_batch(batch)
+                    systems, targets, extra_data = batch_to(
+                        systems, targets, extra_data, dtype=dtype, device=device
+                    )
+                    predictions = evaluate_model(
+                        model,
+                        systems,
+                        {key: train_targets[key] for key in targets.keys()},
+                        is_training=False,
+                    )
 
-                # average by the number of atoms
-                predictions = average_by_num_atoms(
-                    predictions, systems, per_structure_targets
-                )
-                targets = average_by_num_atoms(targets, systems, per_structure_targets)
-                val_loss_batch = loss_fn(predictions, targets, extra_data)
+                    # average by the number of atoms
+                    predictions = average_by_num_atoms(
+                        predictions, systems, per_structure_targets
+                    )
+                    targets = average_by_num_atoms(
+                        targets, systems, per_structure_targets
+                    )
+                    val_loss_batch = loss_fn(predictions, targets, extra_data)
 
-                if is_distributed:
-                    # sum the loss over all processes
-                    torch.distributed.all_reduce(val_loss_batch)
-                val_loss += val_loss_batch.item()
+                    if is_distributed:
+                        # sum the loss over all processes
+                        torch.distributed.all_reduce(val_loss_batch)
+                    val_loss += val_loss_batch.item()
 
-                scaled_predictions = (model.module if is_distributed else model).scaler(
-                    systems, predictions
-                )
-                scaled_targets = (model.module if is_distributed else model).scaler(
-                    systems, targets
-                )
-                val_rmse_calculator.update(
-                    scaled_predictions, scaled_targets, extra_data
-                )
-                if self.hypers["log_mae"]:
-                    val_mae_calculator.update(
+                    scaled_predictions = (
+                        model.module if is_distributed else model
+                    ).scaler(systems, predictions)
+                    scaled_targets = (model.module if is_distributed else model).scaler(
+                        systems, targets
+                    )
+                    val_rmse_calculator.update(
                         scaled_predictions, scaled_targets, extra_data
                     )
+                    if self.hypers["log_mae"]:
+                        val_mae_calculator.update(
+                            scaled_predictions, scaled_targets, extra_data
+                        )
 
             lr_scheduler.step(metrics=val_loss)
 
diff --git a/src/metatrain/llpr/trainer.py b/src/metatrain/llpr/trainer.py
index a00bf0683..2364cb4d9 100644
--- a/src/metatrain/llpr/trainer.py
+++ b/src/metatrain/llpr/trainer.py
@@ -462,55 +462,61 @@ def train(
                     )
                 )
 
-            val_loss = 0.0
-            for batch in val_dataloader:
-                # Skip None batches (those outside batch_atom_bounds)
-                if should_skip_batch(batch, is_distributed, device):
-                    continue
+            with torch.set_grad_enabled(
+                any(target_info.gradients for target_info in train_targets.values())
+            ):  # keep gradients on if any of the targets require them
+                val_loss = 0.0
+                for batch in val_dataloader:
+                    # Skip None batches (those outside batch_atom_bounds)
+                    if should_skip_batch(batch, is_distributed, device):
+                        continue
+
+                    systems, targets, extra_data = unpack_batch(batch)
+                    systems, targets, extra_data = batch_to(
+                        systems, targets, extra_data, device=device
+                    )
+                    systems, targets, extra_data = batch_to(
+                        systems, targets, extra_data, dtype=dtype
+                    )
+                    predictions = evaluate_model(
+                        model,
+                        systems,
+                        requested_outputs,
+                        is_training=False,
+                    )
+                    val_loss_batch = loss_fn(predictions, targets, extra_data)
 
-                systems, targets, extra_data = unpack_batch(batch)
-                systems, targets, extra_data = batch_to(
-                    systems, targets, extra_data, device=device
-                )
-                systems, targets, extra_data = batch_to(
-                    systems, targets, extra_data, dtype=dtype
-                )
-                predictions = evaluate_model(
-                    model,
-                    systems,
-                    requested_outputs,
-                    is_training=False,
-                )
-                val_loss_batch = loss_fn(predictions, targets, extra_data)
+                    if is_distributed:
+                        # sum the loss over all processes
+                        torch.distributed.all_reduce(val_loss_batch)
+                    val_loss += val_loss_batch.item()
 
-                if is_distributed:
-                    # sum the loss over all processes
-                    torch.distributed.all_reduce(val_loss_batch)
-                val_loss += val_loss_batch.item()
+                    predictions = average_by_num_atoms(
+                        predictions, systems, per_structure_targets
+                    )
+                    targets = average_by_num_atoms(
+                        targets, systems, per_structure_targets
+                    )
 
-                predictions = average_by_num_atoms(
-                    predictions, systems, per_structure_targets
-                )
-                targets = average_by_num_atoms(targets, systems, per_structure_targets)
+                    targets = _drop_gradient_blocks(targets)
+                    val_rmse_calculator.update(predictions, targets)
+                    if self.hypers["log_mae"]:
+                        val_mae_calculator.update(predictions, targets)
 
-                targets = _drop_gradient_blocks(targets)
-                val_rmse_calculator.update(predictions, targets)
+                finalized_val_info = val_rmse_calculator.finalize(
+                    not_per_atom=["positions_gradients"] + per_structure_targets,
+                    is_distributed=is_distributed,
+                    device=device,
+                )
                 if self.hypers["log_mae"]:
-                    val_mae_calculator.update(predictions, targets)
-
-            finalized_val_info = val_rmse_calculator.finalize(
-                not_per_atom=["positions_gradients"] + per_structure_targets,
-                is_distributed=is_distributed,
-                device=device,
-            )
-            if self.hypers["log_mae"]:
-                finalized_val_info.update(
-                    val_mae_calculator.finalize(
-                        not_per_atom=["positions_gradients"] + per_structure_targets,
-                        is_distributed=is_distributed,
-                        device=device,
+                    finalized_val_info.update(
+                        val_mae_calculator.finalize(
+                            not_per_atom=["positions_gradients"]
+                            + per_structure_targets,
+                            is_distributed=is_distributed,
+                            device=device,
+                        )
                     )
-                )
 
             # Now we log the information:
             finalized_train_info = {
diff --git a/src/metatrain/pet/trainer.py b/src/metatrain/pet/trainer.py
index 8194adc7d..2b6f25925 100644
--- a/src/metatrain/pet/trainer.py
+++ b/src/metatrain/pet/trainer.py
@@ -449,48 +449,53 @@ def train(
                     )
                 )
 
-            val_loss = 0.0
-            for batch in val_dataloader:
-                # Skip None batches (those outside batch_atom_bounds)
-                if should_skip_batch(batch, is_distributed, device):
-                    continue
-
-                systems, targets, extra_data = unpack_batch(batch)
-                systems, targets, extra_data = batch_to(
-                    systems, targets, extra_data, dtype=dtype, device=device
-                )
-                predictions = evaluate_model(
-                    model,
-                    systems,
-                    {key: train_targets[key] for key in targets.keys()},
-                    is_training=False,
-                )
+            with torch.set_grad_enabled(
+                any(target_info.gradients for target_info in train_targets.values())
+            ):  # keep gradients on if any of the targets require them
+                val_loss = 0.0
+                for batch in val_dataloader:
+                    # Skip None batches (those outside batch_atom_bounds)
+                    if should_skip_batch(batch, is_distributed, device):
+                        continue
+
+                    systems, targets, extra_data = unpack_batch(batch)
+                    systems, targets, extra_data = batch_to(
+                        systems, targets, extra_data, dtype=dtype, device=device
+                    )
+                    predictions = evaluate_model(
+                        model,
+                        systems,
+                        {key: train_targets[key] for key in targets.keys()},
+                        is_training=False,
+                    )
 
-                # average by the number of atoms
-                predictions = average_by_num_atoms(
-                    predictions, systems, per_structure_targets
-                )
-                targets = average_by_num_atoms(targets, systems, per_structure_targets)
-                val_loss_batch = loss_fn(predictions, targets, extra_data)
+                    # average by the number of atoms
+                    predictions = average_by_num_atoms(
+                        predictions, systems, per_structure_targets
+                    )
+                    targets = average_by_num_atoms(
+                        targets, systems, per_structure_targets
+                    )
+                    val_loss_batch = loss_fn(predictions, targets, extra_data)
 
-                if is_distributed:
-                    # sum the loss over all processes
-                    torch.distributed.all_reduce(val_loss_batch)
-                val_loss += val_loss_batch.item()
+                    if is_distributed:
+                        # sum the loss over all processes
+                        torch.distributed.all_reduce(val_loss_batch)
+                    val_loss += val_loss_batch.item()
 
-                scaled_predictions = (model.module if is_distributed else model).scaler(
-                    systems, predictions
-                )
-                scaled_targets = (model.module if is_distributed else model).scaler(
-                    systems, targets
-                )
-                val_rmse_calculator.update(
-                    scaled_predictions, scaled_targets, extra_data
-                )
-                if self.hypers["log_mae"]:
-                    val_mae_calculator.update(
+                    scaled_predictions = (
+                        model.module if is_distributed else model
+                    ).scaler(systems, predictions)
+                    scaled_targets = (model.module if is_distributed else model).scaler(
+                        systems, targets
+                    )
+                    val_rmse_calculator.update(
                         scaled_predictions, scaled_targets, extra_data
                     )
+                    if self.hypers["log_mae"]:
+                        val_mae_calculator.update(
+                            scaled_predictions, scaled_targets, extra_data
+                        )
 
             finalized_val_info = val_rmse_calculator.finalize(
                 not_per_atom=["positions_gradients"] + per_structure_targets,
diff --git a/src/metatrain/soap_bpnn/trainer.py b/src/metatrain/soap_bpnn/trainer.py
index 5779e3068..f4715f762 100644
--- a/src/metatrain/soap_bpnn/trainer.py
+++ b/src/metatrain/soap_bpnn/trainer.py
@@ -412,49 +412,54 @@ def train(
                     )
                 )
 
-            val_loss = 0.0
-            for batch in val_dataloader:
-                # Skip None batches (those outside batch_atom_bounds)
-                if should_skip_batch(batch, is_distributed, device):
-                    continue
-
-                systems, targets, extra_data = unpack_batch(batch)
-                systems, targets, extra_data = batch_to(
-                    systems, targets, extra_data, dtype=dtype, device=device
-                )
+            with torch.set_grad_enabled(
+                any(target_info.gradients for target_info in train_targets.values())
+            ):  # keep gradients on if any of the targets require them
+                val_loss = 0.0
+                for batch in val_dataloader:
+                    # Skip None batches (those outside batch_atom_bounds)
+                    if should_skip_batch(batch, is_distributed, device):
+                        continue
+
+                    systems, targets, extra_data = unpack_batch(batch)
+                    systems, targets, extra_data = batch_to(
+                        systems, targets, extra_data, dtype=dtype, device=device
+                    )
 
-                predictions = evaluate_model(
-                    model,
-                    systems,
-                    {key: train_targets[key] for key in targets.keys()},
-                    is_training=False,
-                )
+                    predictions = evaluate_model(
+                        model,
+                        systems,
+                        {key: train_targets[key] for key in targets.keys()},
+                        is_training=False,
+                    )
 
-                # average by the number of atoms
-                predictions = average_by_num_atoms(
-                    predictions, systems, per_structure_targets
-                )
-                targets = average_by_num_atoms(targets, systems, per_structure_targets)
+                    # average by the number of atoms
+                    predictions = average_by_num_atoms(
+                        predictions, systems, per_structure_targets
+                    )
+                    targets = average_by_num_atoms(
+                        targets, systems, per_structure_targets
+                    )
 
-                val_loss_batch = loss_fn(predictions, targets, extra_data)
+                    val_loss_batch = loss_fn(predictions, targets, extra_data)
 
-                if is_distributed:
-                    # sum the loss over all processes
-                    torch.distributed.all_reduce(val_loss_batch)
-                val_loss += val_loss_batch.item()
-                scaled_predictions = (model.module if is_distributed else model).scaler(
-                    systems, predictions
-                )
-                scaled_targets = (model.module if is_distributed else model).scaler(
-                    systems, targets
-                )
-                val_rmse_calculator.update(
-                    scaled_predictions, scaled_targets, extra_data
-                )
-                if self.hypers["log_mae"]:
-                    val_mae_calculator.update(
+                    if is_distributed:
+                        # sum the loss over all processes
+                        torch.distributed.all_reduce(val_loss_batch)
+                    val_loss += val_loss_batch.item()
+                    scaled_predictions = (
+                        model.module if is_distributed else model
+                    ).scaler(systems, predictions)
+                    scaled_targets = (model.module if is_distributed else model).scaler(
+                        systems, targets
+                    )
+                    val_rmse_calculator.update(
                         scaled_predictions, scaled_targets, extra_data
                     )
+                    if self.hypers["log_mae"]:
+                        val_mae_calculator.update(
+                            scaled_predictions, scaled_targets, extra_data
+                        )
 
             finalized_val_info = val_rmse_calculator.finalize(
                 not_per_atom=["positions_gradients"] + per_structure_targets,