-
Notifications
You must be signed in to change notification settings - Fork 23
Add Spin/Charge System Conditioning to PET #1040
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
JonathanSchmidt1
wants to merge
9
commits into
metatensor:main
Choose a base branch
from
JonathanSchmidt1:only_system_conditioning
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+514
−6
Open
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
f81dfbd
add charge and spin to dataset
JonathanSchmidt1 d6bfd3e
add spin/charge hyperparameters to systemhypers
JonathanSchmidt1 6e619e4
add spin/charge hyperparameters to systemhypers
JonathanSchmidt1 831530a
small fixes
JonathanSchmidt1 b0ca0d8
add new checkpoint
JonathanSchmidt1 f6ff254
new checkpoint file
JonathanSchmidt1 d6e5871
revert toml change
JonathanSchmidt1 68acbe2
add checkpoints for classifier and llpr
JonathanSchmidt1 847b8f5
add checkpoints for classifier and llpr
JonathanSchmidt1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file added
BIN
+28.7 KB
src/metatrain/experimental/classifier/tests/checkpoints/model-v2_trainer-v1.ckpt.gz
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| """System-level conditioning embeddings for charge and spin.""" | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| class SystemConditioningEmbedding(torch.nn.Module): | ||
| """Embeds per-system charge and spin multiplicity into per-atom features. | ||
|
|
||
| Each system in a batch can have a different total charge and spin multiplicity. | ||
| These are embedded via learned lookup tables, concatenated, and projected down | ||
| to the model's feature dimension. The resulting per-system embedding is broadcast | ||
| to all atoms belonging to that system. | ||
|
|
||
| :param d_out: Output embedding dimension (should match d_node). | ||
| :param max_charge: Maximum absolute charge value. Supports charges in | ||
| the range ``[-max_charge, +max_charge]``. | ||
| :param max_spin: Maximum spin multiplicity (2S+1). Supports values in | ||
| the range ``[1, max_spin]``. | ||
| """ | ||
|
|
||
| def __init__(self, d_out: int, max_charge: int = 10, max_spin: int = 10): | ||
| super().__init__() | ||
| self.max_charge = max_charge | ||
| self.max_spin = max_spin | ||
| d_inner = d_out | ||
| self.charge_embedding = torch.nn.Embedding(2 * max_charge + 1, d_inner) | ||
| self.spin_embedding = torch.nn.Embedding(max_spin, d_inner) | ||
| gate = torch.nn.Linear(d_inner, d_out) | ||
| torch.nn.init.zeros_(gate.weight) | ||
| torch.nn.init.zeros_(gate.bias) | ||
| self.project = torch.nn.Sequential( | ||
| torch.nn.Linear(2 * d_inner, d_inner), | ||
| torch.nn.SiLU(), | ||
| gate, | ||
| ) | ||
|
|
||
| def validate(self, charge: torch.Tensor, spin: torch.Tensor) -> None: | ||
| """Check that charge and spin values are within the supported range. | ||
|
|
||
| Call this outside of ``torch.compile`` regions to get descriptive errors. | ||
|
|
||
| :param charge: Per-system total charge, shape ``[n_systems]``. | ||
| :param spin: Per-system spin multiplicity, shape ``[n_systems]``. | ||
| """ | ||
| if (charge < -self.max_charge).any() or (charge > self.max_charge).any(): | ||
| raise ValueError( | ||
| f"charge values must be in [{-self.max_charge}, " | ||
| f"{self.max_charge}], got min={charge.min().item()}, " | ||
| f"max={charge.max().item()}. Increase max_charge in " | ||
| f"model hypers to support wider charge ranges." | ||
| ) | ||
| if (spin < 1).any() or (spin > self.max_spin).any(): | ||
| raise ValueError( | ||
| f"spin multiplicity values must be in [1, " | ||
| f"{self.max_spin}], got min={spin.min().item()}, " | ||
| f"max={spin.max().item()}. Increase max_spin in " | ||
| f"model hypers to support higher spin multiplicities." | ||
| ) | ||
|
|
||
| def forward( | ||
| self, | ||
| charge: torch.Tensor, | ||
| spin: torch.Tensor, | ||
| system_indices: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| """Compute per-atom conditioning features from per-system charge and spin. | ||
|
|
||
| :param charge: Per-system total charge, shape ``[n_systems]``, integer. | ||
| :param spin: Per-system spin multiplicity (2S+1), shape ``[n_systems]``, | ||
| integer >= 1. | ||
| :param system_indices: Maps each atom to its system index, | ||
| shape ``[n_atoms]``. | ||
| :return: Per-atom conditioning features, shape ``[n_atoms, d_out]``. | ||
| """ | ||
| c_emb = self.charge_embedding(charge + self.max_charge) # [n_systems, d_out] | ||
| s_emb = self.spin_embedding(spin - 1) # [n_systems, d_out] | ||
| system_emb = self.project(torch.cat([c_emb, s_emb], dim=-1)) # [n_systems, d] | ||
| return system_emb[system_indices] # [n_atoms, d_out] |
Binary file not shown.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@frostedoyster is this needed? From the last ml-devel meeting I understood that in principle changes to an architecture's checkpoint shouldn't need a version upgrade on the wrapper.
I think what needs to be done is to simply upgrade the checkpoint file with the new file, without any renaming, but let's see what Filippo says.
(same for llpr)