You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/source/models.rst
+52-31Lines changed: 52 additions & 31 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -83,38 +83,9 @@ This is a minimal example of a custom training loop:
83
83
optimizer.step()
84
84
85
85
86
-
87
-
88
-
Loading a model for inference
89
-
-----------------------------
90
-
91
-
Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example.
.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference.
109
-
110
-
.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case.
111
-
112
-
.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.
113
-
114
-
115
86
.. _delta-learning:
116
87
Training on relative energies
117
-
-----------------------------
88
+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
118
89
119
90
It might be useful to train the model on relative energies but then make the model produce total energies when running inference.
120
91
TorchMD-Net supports delta training via the :code:`remove_ref_energy` option. Passing this option when training (either via the :ref:`configuration-file` or using the :ref:`torchmd-train` command line interface) will subtract the reference energy from each atom in a sample before passing it to the model.
@@ -126,7 +97,7 @@ If :code:`remove_ref_energy` is turned on, the reference energy is stored in the
126
97
.. note:: The reference energies are stored as an :py:mod:`torchmdnet.priors.Atomref` prior with :code:`enable=False`.
127
98
128
99
Example
129
-
~~~~~~~
100
+
********
130
101
131
102
First we train a model with the :code:`remove_ref_energy` option turned on:
132
103
@@ -151,6 +122,56 @@ Then we load the model for inference:
151
122
batch = torch.zeros(len(z), dtype=torch.long)
152
123
153
124
y, neg_dy = model(z, pos, batch)
125
+
126
+
127
+
Loading a model for inference
128
+
-----------------------------
129
+
130
+
Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example.
.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference.
148
+
149
+
.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case.
150
+
151
+
.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.
152
+
153
+
154
+
Model Ensembles
155
+
---------------
156
+
It is possible to create an ensemble of models by loading multiple checkpoints and averaging their predictions. The following example shows how to do this:
.. note:: :py:mod:`torchmdnet.models.model.load_model` will return an instance of :py:mod:`torchmdnet.models.model.Ensemble` if a list of checkpoints is passed. The :code:`return_std` option can be used to return the standard deviation of the predictions.
0 commit comments