dpa3 architecture #725
Open
littlepeachs wants to merge 24 commits into metatensor:main from littlepeachs:DPA3
Commits (24)
81ea25b  add dpa3 architecture
13b8500  fix dpa3 interface & add dependency
385170b  add unit tests
d23aaa1  fix unit tests
3162c91  Merge branch 'main' into DPA3
243fc48  DPA3 review revision 0924 (ceriottm)
96b6e11  add code owner for DPA3
69895b3  Lint and add documentation
a60b92f  Merge branch 'main' into DPA3 (frostedoyster)
5204317  Fix tests (frostedoyster)
c536d42  Update to new loss code (frostedoyster)
c49063f  Fix tests (frostedoyster)
e676aaa  Fix tests (frostedoyster)
53b5044  Fix docs (frostedoyster)
9474893  Fix docs (frostedoyster)
ecfa4dd  Merge branch 'main' into DPA3 (frostedoyster)
335a00b  Merge branch 'main' into DPA3 (Luthaf)
43e953c  Merge branch 'main' into DPA3 (Luthaf)
10fb6cc  Update dpa3 to follow with recent changes (Luthaf)
001f4a4  Merge branch 'main' into DPA3 (Luthaf)
895e858  linting (pfebrer)
e2b1e34  Moved to shared tests (some still failing) (pfebrer)
f57414f  Make docs build (pfebrer)
15ae9a0  Fix test typo (pfebrer)
Files changed
One change adds dpa3 to the list of architectures in a CI workflow:

```diff
@@ -19,6 +19,7 @@ jobs:
   - llpr
   - mace
   - nanopet
+  - dpa3
   - pet
   - soap-bpnn
```
The new architecture package registers its entry points in an `__init__.py` (the file path is not rendered in this view):

```python
from .model import DPA3
from .trainer import Trainer

__model__ = DPA3
__trainer__ = Trainer

__authors__ = [
    ("Duo Zhang <zhduodyx@pku.edu.cn>", "@duozhang"),
]

__maintainers__ = [
    ("Duo Zhang <zhduodyx@pku.edu.cn>", "@duozhang"),
    ("Wentao Li <liwt24@mails.tsinghua.edu.cn>", "@wentaoli"),
]
```
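These double-underscore attributes are how an architecture exposes its model and trainer classes. As a minimal sketch of picking them up by hand (the module path below is hypothetical, and metatrain's real discovery logic may differ):

```python
# Hedged sketch, not metatrain's actual loader; only the __model__ and
# __trainer__ attributes come from the diff above, the path is a guess.
import importlib

arch = importlib.import_module("metatrain.experimental.dpa3")  # hypothetical path
ModelClass = arch.__model__      # -> DPA3
TrainerClass = arch.__trainer__  # -> Trainer
```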
One added file is empty.
The bulk of the PR is a new hyperparameter module (149 added lines). First, the model hyperparameters:

```python
"""
DPA3 (experimental)
===================

This is an interface to the DPA3 architecture described in
https://arxiv.org/abs/2506.01686 and implemented in deepmd-kit
(https://github.com/deepmodeling/deepmd-kit).
"""

from typing import Literal, Optional

from typing_extensions import TypedDict

from metatrain.utils.additive import FixedCompositionWeights
from metatrain.utils.hypers import init_with_defaults
from metatrain.utils.loss import LossSpecification


###########################
#  MODEL HYPERPARAMETERS  #
###########################


class RepflowHypers(TypedDict):
    n_dim: int = 128
    e_dim: int = 64
    a_dim: int = 32
    nlayers: int = 6
    e_rcut: float = 6.0
    e_rcut_smth: float = 5.3
    e_sel: int = 1200
    a_rcut: float = 4.0
    a_rcut_smth: float = 3.5
    a_sel: int = 300
    axis_neuron: int = 4
    skip_stat: bool = True
    a_compress_rate: int = 1
    a_compress_e_rate: int = 2
    a_compress_use_split: bool = True
    update_angle: bool = True
    # TODO: what are the options here
    update_style: str = "res_residual"
    update_residual: float = 0.1
    # TODO: what are the options here
    update_residual_init: str = "const"
    smooth_edge_update: bool = True
    use_dynamic_sel: bool = True
    sel_reduce_factor: float = 10.0


class DescriptorHypers(TypedDict):
    # TODO: what are the options here
    type: str = "dpa3"
    repflow: RepflowHypers = init_with_defaults(RepflowHypers)
    # TODO: what are the options here
    activation_function: str = "custom_silu:10.0"
    use_tebd_bias: bool = False
    # TODO: what are the options here
    precision: str = "float32"
    concat_output_tebd: bool = False


class FittingNetHypers(TypedDict):
    neuron: list[int] = [240, 240, 240]
    resnet_dt: bool = True
    seed: int = 1
    # TODO: what are the options here
    precision: str = "float32"
    # TODO: what are the options here
    activation_function: str = "custom_silu:10.0"
    # TODO: what are the options here
    type: str = "ener"
    numb_fparam: int = 0
    numb_aparam: int = 0
    dim_case_embd: int = 0
    trainable: bool = True
    rcond: Optional[float] = None
    atom_ener: list[float] = []
    use_aparam_as_mask: bool = False


class ModelHypers(TypedDict):
    """Hyperparameters for the DPA3 model."""

    type_map: list[str] = ["H", "C", "N", "O"]

    descriptor: DescriptorHypers = init_with_defaults(DescriptorHypers)
    fitting_net: FittingNetHypers = init_with_defaults(FittingNetHypers)
```
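To make the schema concrete, here is a minimal sketch of overriding a couple of these defaults. It assumes that `init_with_defaults` (imported above and already used for the nested `repflow`, `descriptor`, and `fitting_net` fields) returns a plain dict populated with the declared defaults:

```python
# Sketch only: assumes init_with_defaults() yields a dict of the defaults
# declared in the TypedDict classes above, as its use for the nested
# hypers in this file suggests.
hypers = init_with_defaults(ModelHypers)
hypers["type_map"] = ["H", "O"]                  # restrict the element map
hypers["descriptor"]["repflow"]["e_rcut"] = 5.0  # shorter edge cutoff
hypers["descriptor"]["repflow"]["nlayers"] = 4   # fewer repflow layers
```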
Then, in the same module, the trainer hyperparameters:

```python
##############################
#  TRAINER HYPERPARAMETERS   #
##############################


class TrainerHypers(TypedDict):
    distributed: bool = False
    """Whether to use distributed training"""
    distributed_port: int = 39591
    """Port for DDP communication"""
    batch_size: int = 8
    """The number of samples to use in each batch of training. This
    hyperparameter controls the tradeoff between training speed and memory
    usage. In general, larger batch sizes will lead to faster training, but
    might require more memory."""
    num_epochs: int = 100
    """Number of epochs."""
    learning_rate: float = 0.001
    """Learning rate."""

    # TODO: update the scheduler or not
    scheduler_patience: int = 100
    scheduler_factor: float = 0.8

    log_interval: int = 1
    """Interval to log metrics."""
    checkpoint_interval: int = 100
    """Interval to save checkpoints."""
    scale_targets: bool = True
    """Normalize targets to unit std during training."""
    fixed_composition_weights: FixedCompositionWeights = {}
    """Weights for atomic contributions.

    This is passed to the ``fixed_weights`` argument of
    :meth:`CompositionModel.train_model
    <metatrain.utils.additive.composition.CompositionModel.train_model>`;
    see its documentation to understand exactly what to pass here.
    """
    # fixed_scaling_weights: FixedScalerWeights = {}
    # """Weights for target scaling.

    # This is passed to the ``fixed_weights`` argument of
    # :meth:`Scaler.train_model <metatrain.utils.scaler.scaler.Scaler.train_model>`;
    # see its documentation to understand exactly what to pass here.
    # """
    per_structure_targets: list[str] = []
    """Targets to calculate per-structure losses for."""
    # num_workers: Optional[int] = None
    # """Number of workers for data loading. If not provided, it is set
    # automatically."""
    log_mae: bool = False
    """Log MAE alongside RMSE"""
    log_separate_blocks: bool = False
    """Log per-block error."""
    best_model_metric: Literal["rmse_prod", "mae_prod", "loss"] = "rmse_prod"
    """Metric used to select the best checkpoint (e.g., ``rmse_prod``)"""

    loss: str | dict[str, LossSpecification] = "mse"
    """The loss function to be used. See :ref:`loss-functions` for more
    details."""
```
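And a similar hedged sketch for the training side, under the same assumption that `init_with_defaults` returns a plain dict:

```python
# Same assumption as above: init_with_defaults() yields a plain dict of
# the TrainerHypers defaults. Field names come from the class itself.
train_hypers = init_with_defaults(TrainerHypers)
train_hypers["batch_size"] = 16   # larger batches: faster, but more memory
train_hypers["num_epochs"] = 500  # train longer than the default 100
train_hypers["log_mae"] = True    # report MAE alongside RMSE in the logs
```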