Skip to content

Commit 013a691

Browse files
committed
Update documentation
1 parent 5a60b40 commit 013a691

File tree

1 file changed

+52
-31
lines changed

1 file changed

+52
-31
lines changed

docs/source/models.rst

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -83,38 +83,9 @@ This is a minimal example of a custom training loop:
8383
optimizer.step()
8484
8585
86-
87-
88-
Loading a model for inference
89-
-----------------------------
90-
91-
Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example.
92-
93-
.. code:: python
94-
95-
import torch
96-
from torchmdnet.models.model import load_model
97-
checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt"
98-
model = load_model(checkpoint, derivative=True)
99-
# An arbitrary set of inputs for the model
100-
n_atoms = 10
101-
zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
102-
z = zs[torch.randint(0, len(zs), (n_atoms,))]
103-
pos = torch.randn(len(z), 3)
104-
batch = torch.zeros(len(z), dtype=torch.long)
105-
106-
y, neg_dy = model(z, pos, batch)
107-
108-
.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference.
109-
110-
.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case.
111-
112-
.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.
113-
114-
11586
.. _delta-learning:
11687
Training on relative energies
117-
-----------------------------
88+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
11889

11990
It might be useful to train the model on relative energies but then make the model produce total energies when running inference.
12091
TorchMD-Net supports delta training via the :code:`remove_ref_energy` option. Passing this option when training (either via the :ref:`configuration-file` or using the :ref:`torchmd-train` command line interface) will subtract the reference energy from each atom in a sample before passing it to the model.
@@ -126,7 +97,7 @@ If :code:`remove_ref_energy` is turned on, the reference energy is stored in the
12697
.. note:: The reference energies are stored as an :py:mod:`torchmdnet.priors.Atomref` prior with :code:`enable=False`.
12798

12899
Example
129-
~~~~~~~
100+
********
130101

131102
First we train a model with the :code:`remove_ref_energy` option turned on:
132103

@@ -151,6 +122,56 @@ Then we load the model for inference:
151122
batch = torch.zeros(len(z), dtype=torch.long)
152123
153124
y, neg_dy = model(z, pos, batch)
125+
126+
127+
Loading a model for inference
128+
-----------------------------
129+
130+
Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example.
131+
132+
.. code:: python
133+
134+
import torch
135+
from torchmdnet.models.model import load_model
136+
checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt"
137+
model = load_model(checkpoint, derivative=True)
138+
# An arbitrary set of inputs for the model
139+
n_atoms = 10
140+
zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
141+
z = zs[torch.randint(0, len(zs), (n_atoms,))]
142+
pos = torch.randn(len(z), 3)
143+
batch = torch.zeros(len(z), dtype=torch.long)
144+
145+
y, neg_dy = model(z, pos, batch)
146+
147+
.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference.
148+
149+
.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case.
150+
151+
.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.
152+
153+
154+
Model Ensembles
155+
---------------
156+
It is possible to create an ensemble of models by loading multiple checkpoints and averaging their predictions. The following example shows how to do this:
157+
158+
.. code:: python
159+
160+
import torch
161+
from torchmdnet.models.model import load_model
162+
checkpoints = ["/path/to/checkpoint/my_checkpoint1.ckpt", "/path/to/checkpoint/my_checkpoint2.ckpt"]
163+
model_ensemble = load_model(checkpoints, return_std=True)
164+
y_ensemble, neg_dy_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)
165+
166+
167+
.. note:: :py:mod:`torchmdnet.models.model.load_model` will return an instance of :py:mod:`torchmdnet.models.model.Ensemble` if a list of checkpoints is passed. The :code:`return_std` option can be used to return the standard deviation of the predictions.
168+
169+
170+
171+
.. autoclass:: torchmdnet.models.model.Ensemble
172+
:noindex:
173+
174+
154175

155176

156177

0 commit comments

Comments
 (0)