79 changes: 42 additions & 37 deletions src/metatrain/experimental/flashmd/trainer.py
@@ -472,48 +472,53 @@ def train(
)
)

val_loss = 0.0
for batch in val_dataloader:
# Skip None batches (those outside batch_atom_bounds)
if should_skip_batch(batch, is_distributed, device):
continue

systems, targets, extra_data = unpack_batch(batch)
systems, targets, extra_data = batch_to(
systems, targets, extra_data, dtype=dtype, device=device
)
predictions = evaluate_model(
model,
systems,
{key: train_targets[key] for key in targets.keys()},
is_training=False,
)
with torch.set_grad_enabled(
any(target_info.gradients for target_info in train_targets.values())
): # keep gradients on if any of the targets require them
val_loss = 0.0
for batch in val_dataloader:
# Skip None batches (those outside batch_atom_bounds)
if should_skip_batch(batch, is_distributed, device):
continue

systems, targets, extra_data = unpack_batch(batch)
systems, targets, extra_data = batch_to(
systems, targets, extra_data, dtype=dtype, device=device
)
predictions = evaluate_model(
model,
systems,
{key: train_targets[key] for key in targets.keys()},
is_training=False,
)

# average by the number of atoms
predictions = average_by_num_atoms(
predictions, systems, per_structure_targets
)
targets = average_by_num_atoms(targets, systems, per_structure_targets)
val_loss_batch = loss_fn(predictions, targets, extra_data)
# average by the number of atoms
predictions = average_by_num_atoms(
predictions, systems, per_structure_targets
)
targets = average_by_num_atoms(
targets, systems, per_structure_targets
)
val_loss_batch = loss_fn(predictions, targets, extra_data)

if is_distributed:
# sum the loss over all processes
torch.distributed.all_reduce(val_loss_batch)
val_loss += val_loss_batch.item()
if is_distributed:
# sum the loss over all processes
torch.distributed.all_reduce(val_loss_batch)
val_loss += val_loss_batch.item()

scaled_predictions = (model.module if is_distributed else model).scaler(
systems, predictions
)
scaled_targets = (model.module if is_distributed else model).scaler(
systems, targets
)
val_rmse_calculator.update(
scaled_predictions, scaled_targets, extra_data
)
if self.hypers["log_mae"]:
val_mae_calculator.update(
scaled_predictions = (
model.module if is_distributed else model
).scaler(systems, predictions)
scaled_targets = (model.module if is_distributed else model).scaler(
systems, targets
)
val_rmse_calculator.update(
scaled_predictions, scaled_targets, extra_data
)
if self.hypers["log_mae"]:
val_mae_calculator.update(
scaled_predictions, scaled_targets, extra_data
)

finalized_val_info = val_rmse_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets,
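The same change recurs in every trainer touched by this PR: the validation loop is wrapped in `torch.set_grad_enabled(...)`, keyed on whether any training target declares gradient components (e.g. forces as derivatives of the energy). A minimal sketch of that pattern, with hypothetical `model`, `loss_fn`, and `train_targets` stand-ins rather than the actual metatrain objects:

```python
import torch

def validate(model, val_dataloader, loss_fn, train_targets):
    # Keep autograd enabled during validation only if at least one target
    # requires gradients (e.g. forces computed as energy derivatives).
    needs_grad = any(info.gradients for info in train_targets.values())
    val_loss = 0.0
    with torch.set_grad_enabled(needs_grad):
        for systems, targets in val_dataloader:
            predictions = model(systems)  # forward pass only, no optimizer step
            val_loss += loss_fn(predictions, targets).item()
    return val_loss
```

The apparent intent is that purely scalar targets still get gradient-free validation, while derivative-based targets keep the autograd graph they need to be evaluated at all.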
79 changes: 42 additions & 37 deletions src/metatrain/experimental/mace/trainer.py
@@ -476,48 +476,53 @@ def train(
)
)

val_loss = 0.0
for batch in val_dataloader:
# Skip None batches (those outside batch_atom_bounds)
if should_skip_batch(batch, is_distributed, device):
continue

systems, targets, extra_data = unpack_batch(batch)
systems, targets, extra_data = batch_to(
systems, targets, extra_data, dtype=dtype, device=device
)
predictions = evaluate_model(
model,
systems,
{key: train_targets[key] for key in targets.keys()},
is_training=False,
)
with torch.set_grad_enabled(
any(target_info.gradients for target_info in train_targets.values())
): # keep gradients on if any of the targets require them
val_loss = 0.0
for batch in val_dataloader:
# Skip None batches (those outside batch_atom_bounds)
if should_skip_batch(batch, is_distributed, device):
continue

systems, targets, extra_data = unpack_batch(batch)
systems, targets, extra_data = batch_to(
systems, targets, extra_data, dtype=dtype, device=device
)
predictions = evaluate_model(
model,
systems,
{key: train_targets[key] for key in targets.keys()},
is_training=False,
)

# average by the number of atoms
predictions = average_by_num_atoms(
predictions, systems, per_structure_targets
)
targets = average_by_num_atoms(targets, systems, per_structure_targets)
val_loss_batch = loss_fn(predictions, targets, extra_data)
# average by the number of atoms
predictions = average_by_num_atoms(
predictions, systems, per_structure_targets
)
targets = average_by_num_atoms(
targets, systems, per_structure_targets
)
val_loss_batch = loss_fn(predictions, targets, extra_data)

if is_distributed:
# sum the loss over all processes
torch.distributed.all_reduce(val_loss_batch)
val_loss += val_loss_batch.item()
if is_distributed:
# sum the loss over all processes
torch.distributed.all_reduce(val_loss_batch)
val_loss += val_loss_batch.item()

scaled_predictions = (model.module if is_distributed else model).scaler(
systems, predictions
)
scaled_targets = (model.module if is_distributed else model).scaler(
systems, targets
)
val_rmse_calculator.update(
scaled_predictions, scaled_targets, extra_data
)
if self.hypers["log_mae"]:
val_mae_calculator.update(
scaled_predictions = (
model.module if is_distributed else model
).scaler(systems, predictions)
scaled_targets = (model.module if is_distributed else model).scaler(
systems, targets
)
val_rmse_calculator.update(
scaled_predictions, scaled_targets, extra_data
)
if self.hypers["log_mae"]:
val_mae_calculator.update(
scaled_predictions, scaled_targets, extra_data
)

lr_scheduler.step(metrics=val_loss)

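Unchanged by the PR but central to these loops is the distributed reduction of the per-batch validation loss. A self-contained sketch of that idiom (process-group initialization is assumed to have happened elsewhere; the helper name is illustrative, not from this repository):

```python
import torch
import torch.distributed as dist

def accumulate_val_loss(
    val_loss: float, val_loss_batch: torch.Tensor, is_distributed: bool
) -> float:
    # Sum the batch loss over all ranks (all_reduce defaults to SUM, in place)
    # so every process accumulates the same global validation loss.
    if is_distributed:
        dist.all_reduce(val_loss_batch)
    return val_loss + val_loss_batch.item()
```

This keeps `val_loss` identical across ranks, which matters here because `lr_scheduler.step(metrics=val_loss)` is driven by it.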
92 changes: 49 additions & 43 deletions src/metatrain/llpr/trainer.py
@@ -462,55 +462,61 @@ def train(
)
)

val_loss = 0.0
for batch in val_dataloader:
# Skip None batches (those outside batch_atom_bounds)
if should_skip_batch(batch, is_distributed, device):
continue
with torch.set_grad_enabled(
any(target_info.gradients for target_info in train_targets.values())
): # keep gradients on if any of the targets require them
val_loss = 0.0
for batch in val_dataloader:
# Skip None batches (those outside batch_atom_bounds)
if should_skip_batch(batch, is_distributed, device):
continue

systems, targets, extra_data = unpack_batch(batch)
systems, targets, extra_data = batch_to(
systems, targets, extra_data, device=device
)
systems, targets, extra_data = batch_to(
systems, targets, extra_data, dtype=dtype
)
predictions = evaluate_model(
model,
systems,
requested_outputs,
is_training=False,
)
val_loss_batch = loss_fn(predictions, targets, extra_data)

systems, targets, extra_data = unpack_batch(batch)
systems, targets, extra_data = batch_to(
systems, targets, extra_data, device=device
)
systems, targets, extra_data = batch_to(
systems, targets, extra_data, dtype=dtype
)
predictions = evaluate_model(
model,
systems,
requested_outputs,
is_training=False,
)
val_loss_batch = loss_fn(predictions, targets, extra_data)
if is_distributed:
# sum the loss over all processes
torch.distributed.all_reduce(val_loss_batch)
val_loss += val_loss_batch.item()

if is_distributed:
# sum the loss over all processes
torch.distributed.all_reduce(val_loss_batch)
val_loss += val_loss_batch.item()
predictions = average_by_num_atoms(
predictions, systems, per_structure_targets
)
targets = average_by_num_atoms(
targets, systems, per_structure_targets
)

predictions = average_by_num_atoms(
predictions, systems, per_structure_targets
)
targets = average_by_num_atoms(targets, systems, per_structure_targets)
targets = _drop_gradient_blocks(targets)
val_rmse_calculator.update(predictions, targets)
if self.hypers["log_mae"]:
val_mae_calculator.update(predictions, targets)

targets = _drop_gradient_blocks(targets)
val_rmse_calculator.update(predictions, targets)
finalized_val_info = val_rmse_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets,
is_distributed=is_distributed,
device=device,
)
if self.hypers["log_mae"]:
val_mae_calculator.update(predictions, targets)

finalized_val_info = val_rmse_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets,
is_distributed=is_distributed,
device=device,
)
if self.hypers["log_mae"]:
finalized_val_info.update(
val_mae_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets,
is_distributed=is_distributed,
device=device,
finalized_val_info.update(
val_mae_calculator.finalize(
not_per_atom=["positions_gradients"]
+ per_structure_targets,
is_distributed=is_distributed,
device=device,
)
)
)

# Now we log the information:
finalized_train_info = {
79 changes: 42 additions & 37 deletions src/metatrain/pet/trainer.py
@@ -449,48 +449,53 @@ def train(
)
)

val_loss = 0.0
for batch in val_dataloader:
# Skip None batches (those outside batch_atom_bounds)
if should_skip_batch(batch, is_distributed, device):
continue

systems, targets, extra_data = unpack_batch(batch)
systems, targets, extra_data = batch_to(
systems, targets, extra_data, dtype=dtype, device=device
)
predictions = evaluate_model(
model,
systems,
{key: train_targets[key] for key in targets.keys()},
is_training=False,
)
with torch.set_grad_enabled(
any(target_info.gradients for target_info in train_targets.values())
): # keep gradients on if any of the targets require them
val_loss = 0.0
for batch in val_dataloader:
# Skip None batches (those outside batch_atom_bounds)
if should_skip_batch(batch, is_distributed, device):
continue

systems, targets, extra_data = unpack_batch(batch)
systems, targets, extra_data = batch_to(
systems, targets, extra_data, dtype=dtype, device=device
)
predictions = evaluate_model(
model,
systems,
{key: train_targets[key] for key in targets.keys()},
is_training=False,
)

# average by the number of atoms
predictions = average_by_num_atoms(
predictions, systems, per_structure_targets
)
targets = average_by_num_atoms(targets, systems, per_structure_targets)
val_loss_batch = loss_fn(predictions, targets, extra_data)
# average by the number of atoms
predictions = average_by_num_atoms(
predictions, systems, per_structure_targets
)
targets = average_by_num_atoms(
targets, systems, per_structure_targets
)
val_loss_batch = loss_fn(predictions, targets, extra_data)

if is_distributed:
# sum the loss over all processes
torch.distributed.all_reduce(val_loss_batch)
val_loss += val_loss_batch.item()
if is_distributed:
# sum the loss over all processes
torch.distributed.all_reduce(val_loss_batch)
val_loss += val_loss_batch.item()

scaled_predictions = (model.module if is_distributed else model).scaler(
systems, predictions
)
scaled_targets = (model.module if is_distributed else model).scaler(
systems, targets
)
val_rmse_calculator.update(
scaled_predictions, scaled_targets, extra_data
)
if self.hypers["log_mae"]:
val_mae_calculator.update(
scaled_predictions = (
model.module if is_distributed else model
).scaler(systems, predictions)
scaled_targets = (model.module if is_distributed else model).scaler(
systems, targets
)
val_rmse_calculator.update(
scaled_predictions, scaled_targets, extra_data
)
if self.hypers["log_mae"]:
val_mae_calculator.update(
scaled_predictions, scaled_targets, extra_data
)

finalized_val_info = val_rmse_calculator.finalize(
not_per_atom=["positions_gradients"] + per_structure_targets,