<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>5. PyTorch Models — FitSNAP documentation</title>
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
<script src="_static/js/html5shiv.min.js"></script>
<![endif]-->
<script data-url_root="./" id="documentation_options" src="_static/documentation_options.js"></script>
<script src="_static/jquery.js"></script>
<script src="_static/underscore.js"></script>
<script src="_static/_sphinx_javascript_frameworks_compat.js"></script>
<script src="_static/doctools.js"></script>
<script src="_static/sphinx_highlight.js"></script>
<script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
<script src="_static/js/theme.js"></script>
<link rel="index" title="Index" href="genindex.html" />
<link rel="search" title="Search" href="search.html" />
<link rel="next" title="1. Contributing" href="Contributing.html" />
<link rel="prev" title="4. Linear Models" href="Linear.html" />
</head>
<body class="wy-body-for-nav">
<div class="wy-grid-for-nav">
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
<div class="wy-side-scroll">
<div class="wy-side-nav-search" >
<a href="index.html">
<img src="_static/FitSNAP.png" class="logo" alt="Logo"/>
</a>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="search.html" method="get">
<input type="text" name="q" placeholder="Search docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
<p class="caption" role="heading"><span class="caption-text">User Guide</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="Introduction.html">1. Introduction</a></li>
<li class="toctree-l1"><a class="reference internal" href="Installation.html">2. Installation</a></li>
<li class="toctree-l1"><a class="reference internal" href="Run/index.html">3. Run FitSNAP</a></li>
<li class="toctree-l1"><a class="reference internal" href="Linear.html">4. Linear Models</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">5. PyTorch Models</a><ul>
<li class="toctree-l2"><a class="reference internal" href="#fitting-neural-network-potentials">5.1. Fitting Neural Network Potentials</a></li>
<li class="toctree-l2"><a class="reference internal" href="#loss-function">5.2. Loss Function</a></li>
<li class="toctree-l2"><a class="reference internal" href="#outputs-and-error-calculation">5.3. Outputs and Error Calculation</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#error-comparison-files">5.3.1. Error/Comparison files</a></li>
<li class="toctree-l3"><a class="reference internal" href="#pytorch-model-files">5.3.2. PyTorch model files</a></li>
<li class="toctree-l3"><a class="reference internal" href="#calculate-errors-on-a-test-set">5.3.3. Calculate errors on a test set</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="#training-performance">5.4. Training Performance</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#gpu-acceleration">5.4.1. GPU Acceleration</a></li>
</ul>
</li>
</ul>
</li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Programmer Guide</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="Contributing.html">1. Contributing</a></li>
<li class="toctree-l1"><a class="reference internal" href="Executable.html">2. Executable</a></li>
<li class="toctree-l1"><a class="reference internal" href="Lib/index.html">3. Library</a></li>
<li class="toctree-l1"><a class="reference internal" href="Tests.html">4. Tests</a></li>
</ul>
</div>
</div>
</nav>
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
<a href="index.html">FitSNAP</a>
</nav>
<div class="wy-nav-content">
<div class="rst-content">
<div role="navigation" aria-label="Page navigation">
<ul class="wy-breadcrumbs">
<li><a href="index.html" class="icon icon-home"></a> »</li>
<li><span class="section-number">5. </span>PyTorch Models</li>
<li class="wy-breadcrumbs-aside">
<a href="_sources/Pytorch.rst.txt" rel="nofollow"> View page source</a>
</li>
</ul>
<hr/>
</div>
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<section id="pytorch-models">
<h1><span class="section-number">5. </span>PyTorch Models<a class="headerlink" href="#pytorch-models" title="Permalink to this heading"></a></h1>
<p>Interfacing with PyTorch allows us to conveniently fit neural network potentials using descriptors
that exist in LAMMPS. We may then use these neural network models to run high-performance MD
simulations in LAMMPS. When fitting atom-centered neural network potentials, we incorporate a
general and performant approach that allows any descriptor as input to the network. This is achieved
by pre-calculating descriptors in LAMMPS which are then fed into the network, as shown below.</p>
<figure class="align-default">
<a class="reference internal image-reference" href="_images/lammps_fitsnap_connection.png"><img alt="_images/lammps_fitsnap_connection.png" src="_images/lammps_fitsnap_connection.png" style="width: 734.4px; height: 312.0px;" /></a>
</figure>
<p>To calculate forces, we use the general chain rule expression shown in the figure above, where the
descriptor derivatives are obtained analytically from LAMMPS. These capabilities are further explained below.</p>
<section id="fitting-neural-network-potentials">
<h2><span class="section-number">5.1. </span>Fitting Neural Network Potentials<a class="headerlink" href="#fitting-neural-network-potentials" title="Permalink to this heading"></a></h2>
<p>Just as we fit linear models, we can input descriptors into nonlinear models such as
neural networks. To do this, we use the same FitSNAP input script that we use for linear
models, with some slight changes to the sections. First we must add a <code class="code docutils literal notranslate"><span class="pre">PYTORCH</span></code> section,
which for the tantalum example looks like:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="p">[</span><span class="n">PYTORCH</span><span class="p">]</span>
<span class="n">layer_sizes</span> <span class="o">=</span> <span class="n">num_desc</span> <span class="mi">60</span> <span class="mi">60</span> <span class="mi">1</span>
<span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">1.5e-4</span>
<span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">save_state_output</span> <span class="o">=</span> <span class="n">Ta_Pytorch</span><span class="o">.</span><span class="n">pt</span>
<span class="n">energy_weight</span> <span class="o">=</span> <span class="mf">1e-2</span>
<span class="n">force_weight</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="n">training_fraction</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="n">multi_element_option</span> <span class="o">=</span> <span class="mi">1</span>
</pre></div>
</div>
<p>We must also add a <code class="code docutils literal notranslate"><span class="pre">nonlinear</span> <span class="pre">=</span> <span class="pre">1</span></code> key in the <code class="code docutils literal notranslate"><span class="pre">CALCULATOR</span></code> section, and set
<code class="code docutils literal notranslate"><span class="pre">solver</span> <span class="pre">=</span> <span class="pre">PYTORCH</span></code> in the <code class="code docutils literal notranslate"><span class="pre">SOLVER</span></code> section. Now the input script is ready to fit a
neural network potential.</p>
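<p>As a minimal sketch of those two changes, the relevant fragments of the input script could look like
the following. Only the two keys named above are shown; any other keys already present in these
sections are assumed to stay as they are in the linear fit.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="p">[</span><span class="n">CALCULATOR</span><span class="p">]</span>
<span class="n">nonlinear</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># required when fitting a neural network</span>

<span class="p">[</span><span class="n">SOLVER</span><span class="p">]</span>
<span class="n">solver</span> <span class="o">=</span> <span class="n">PYTORCH</span> <span class="c1"># select the PyTorch solver</span>
</pre></div>
</div>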
<p>The <code class="code docutils literal notranslate"><span class="pre">PYTORCH</span></code> section keys are explained in more detail below.</p>
<ul>
<li><p><code class="code docutils literal notranslate"><span class="pre">layer_sizes</span></code> determines the network architecture. We lead with a <code class="code docutils literal notranslate"><span class="pre">num_desc</span></code> parameter,
which tells FitSNAP that the number of nodes in the first layer is equal to the number of
descriptors. The argument is a space-separated list in which each element sets the number of nodes in
the corresponding layer.</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">learning_rate</span></code> sets the step size of each gradient descent update, and therefore how fast the
network minimizes the loss function. We find that a learning rate around <code class="code docutils literal notranslate"><span class="pre">1e-4</span></code> works well when
fitting to forces with our current loss function.</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">num_epochs</span></code> sets the number of epochs, i.e. full passes over the training data. Each epoch
performs one gradient descent update per batch.</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">batch_size</span></code> sets how many configs are in each batch; gradients are averaged over the configs
in a batch as we loop over batches in a single epoch. We find that a batch size around 4 works well for our models.</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">save_state_output</span></code> is the name of the PyTorch model file to write after every
epoch. This model can be loaded for testing purposes later.</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">save_state_input</span></code> is the name of a PyTorch model that may be loaded for the purpose of
restarting an existing fit, or for calculating test errors.</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">energy_weight</span></code> is a scalar constant multiplied by the mean squared energy error in the
loss function. Declaring this parameter will override the weights in the GROUPS section for all
configs. We therefore call this the <em>global energy weight</em>. If you want to specify energy weights
for each group, do so in the GROUPS section.</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">force_weight</span></code> is a scalar constant multiplied by the mean squared force error in the loss
function. Declaring this parameter will override the weights in the GROUPS section for all
configs. We therefore call this the <em>global force weight</em>. If you want to specify force weights
for each group, do so in the GROUPS section.</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">training_fraction</span></code> is a decimal fraction of how much of the total data should be trained
on. The leftover <code class="code docutils literal notranslate"><span class="pre">1.0</span> <span class="pre">-</span> <span class="pre">training_fraction</span></code> portion is used for calculating validation errors
during a fit. Declaring this parameter will override the training/testing fractions in the GROUPS
section for all configs. We therefore call this the <em>global training fraction</em>. If you want to
specify training/testing fractions for each group, do so in the GROUPS section.</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">multi_element_option</span></code> is a scalar that determines how to handle multiple element types.</p>
<blockquote>
<div><ul class="simple">
<li><p>1: All element types share the same network. Descriptors may still be different per type.</p></li>
<li><p>2: Each element type has its own network.</p></li>
<li><p>3: (Coming soon) One-hot encoding of element types, where each type shares the same network.</p></li>
</ul>
</div></blockquote>
</li>
<li><p><code class="code docutils literal notranslate"><span class="pre">manual_seed_flag</span></code> is 0 by default. Set it to 1 to force a fixed random seed, which is
useful for debugging purposes.</p></li>
<li><p><code class="code docutils literal notranslate"><span class="pre">shuffle_flag</span></code> is 1 by default and determines whether to shuffle the training data every epoch.</p></li>
</ul>
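<p>To illustrate how several of these keys combine, the sketch below modifies the tantalum section shown
earlier to hold out 20% of the data for validation and to fix the random seed. The value
<code class="code docutils literal notranslate"><span class="pre">0.8</span></code> and the explicit flag settings are illustrative choices, not recommendations taken from the
FitSNAP examples.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="p">[</span><span class="n">PYTORCH</span><span class="p">]</span>
<span class="n">layer_sizes</span> <span class="o">=</span> <span class="n">num_desc</span> <span class="mi">60</span> <span class="mi">60</span> <span class="mi">1</span>
<span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">1.5e-4</span>
<span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">save_state_output</span> <span class="o">=</span> <span class="n">Ta_Pytorch</span><span class="o">.</span><span class="n">pt</span>
<span class="n">energy_weight</span> <span class="o">=</span> <span class="mf">1e-2</span>
<span class="n">force_weight</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="n">training_fraction</span> <span class="o">=</span> <span class="mf">0.8</span> <span class="c1"># train on 80%, validate on the remaining 20%</span>
<span class="n">multi_element_option</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">manual_seed_flag</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># force the random seed, e.g. while debugging</span>
<span class="n">shuffle_flag</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># shuffle training data every epoch (the default)</span>
</pre></div>
</div>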
</section>
<section id="loss-function">
<h2><span class="section-number">5.2. </span>Loss Function<a class="headerlink" href="#loss-function" title="Permalink to this heading"></a></h2>
<p>When fitting neural network potentials we minimize the sum of weighted energy and force mean squared
errors:</p>
<div class="math notranslate nohighlight">
\[\mathcal L = \frac{1}{M} \sum_{m=1}^{M} \frac{1}{N_m} \left\{ w_m^E \left[\hat{E}_m(\theta) - E_m\right]^2 + \frac{w_m^F}{3} \sum_{i=1}^{3N_m} \left[\hat{F}_{mi}(\theta) - F_{mi}\right]^2 \right\}\]</div>
<p>where</p>
<ul class="simple">
<li><p><span class="math notranslate nohighlight">\(M\)</span> is the number of configurations in the training set.</p></li>
<li><p><span class="math notranslate nohighlight">\(m\)</span> indexes a particular configuration.</p></li>
<li><p><span class="math notranslate nohighlight">\(N_m\)</span> is the number of atoms in configuration <span class="math notranslate nohighlight">\(m\)</span>.</p></li>
<li><p><span class="math notranslate nohighlight">\(w_m^E\)</span> is the energy weight of configuration <span class="math notranslate nohighlight">\(m\)</span>. These weights can be set by designating
the particular weights in the <a class="reference external" href="Run.html#groups">[GROUPS] section</a>, or by declaring a global
weight in the <code class="code docutils literal notranslate"><span class="pre">[PYTORCH]</span></code> section, which will override the group weights.</p></li>
<li><p><span class="math notranslate nohighlight">\(\theta\)</span> represents all the model fitting parameters (e.g. the trainable coefficients in a neural network).</p></li>
<li><p><span class="math notranslate nohighlight">\(\hat{E}_m(\theta)\)</span> is the model predicted energy of configuration <span class="math notranslate nohighlight">\(m\)</span>.</p></li>
<li><p><span class="math notranslate nohighlight">\(E_m\)</span> is the target <em>ab initio</em> energy of configuration <span class="math notranslate nohighlight">\(m\)</span>, minus the energy of the LAMMPS
reference potential declared in the <a class="reference external" href="Run.html#reference">[REFERENCE] section</a>.</p></li>
<li><p><span class="math notranslate nohighlight">\(i\)</span> indexes a single Cartesian force component of a single atom; we lump Cartesian indices and atom indices
into a single index here, so it runs over all <span class="math notranslate nohighlight">\(3N_m\)</span> force components of configuration <span class="math notranslate nohighlight">\(m\)</span>.</p></li>
<li><p><span class="math notranslate nohighlight">\(w_m^F\)</span> is the force weight of configuration <span class="math notranslate nohighlight">\(m\)</span>. These weights can be set by designating
the particular weights in the <a class="reference external" href="Run.html#groups">[GROUPS] section</a>, or by declaring a global
weight in the <code class="code docutils literal notranslate"><span class="pre">[PYTORCH]</span></code> section, which will override the group weights.</p></li>
<li><p><span class="math notranslate nohighlight">\(\hat{F}_{mi}(\theta)\)</span> is the model predicted force component <span class="math notranslate nohighlight">\(i\)</span> in configuration <span class="math notranslate nohighlight">\(m\)</span>.</p></li>
<li><p><span class="math notranslate nohighlight">\(F_{mi}\)</span> is the target <em>ab initio</em> force component <span class="math notranslate nohighlight">\(i\)</span> in configuration <span class="math notranslate nohighlight">\(m\)</span>,
minus the corresponding force from the LAMMPS reference potential declared in the
<a class="reference external" href="Run.html#reference">[REFERENCE] section</a>.</p></li>
</ul>
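<p>As a worked illustration of the expression above, consider a hypothetical training set with a single
configuration (<span class="math notranslate nohighlight">\(M = 1\)</span>) of two atoms (<span class="math notranslate nohighlight">\(N_1 = 2\)</span>), using the weights from the tantalum example
(<span class="math notranslate nohighlight">\(w_1^E = 10^{-2}\)</span>, <span class="math notranslate nohighlight">\(w_1^F = 1\)</span>). The errors are made-up numbers chosen for easy arithmetic: an
energy error of 0.2 and an error of 0.1 in each of the six force components.</p>
<div class="math notranslate nohighlight">
\[\sum_{i=1}^{6} [\hat{F}_{1i}(\theta) - F_{1i}]^2 = 6 \times (0.1)^2 = 0.06\]</div>
<div class="math notranslate nohighlight">
\[\mathcal L = \frac{1}{1} \cdot \frac{1}{2} \left\{ 10^{-2} \times (0.2)^2 + \frac{1}{3} \times 0.06 \right\} = \frac{1}{2} \left( 0.0004 + 0.02 \right) = 0.0102\]</div>
<p>With these weights the force term (0.02) is fifty times larger than the energy term (0.0004), which
shows how the choice of weights sets the balance between energy and force errors in the loss.</p>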
<p>This loss also gets evaluated for the validation set for each epoch, so that the screen output looks
something like:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="o">-----</span> <span class="n">epoch</span><span class="p">:</span> <span class="mi">0</span>
<span class="n">Batch</span> <span class="n">averaged</span> <span class="n">train</span><span class="o">/</span><span class="n">val</span> <span class="n">loss</span><span class="p">:</span> <span class="mf">4.002996124327183</span> <span class="mf">4.072216800280979</span>
<span class="n">Epoch</span> <span class="n">time</span> <span class="mf">0.3022959232330322</span>
<span class="o">-----</span> <span class="n">epoch</span><span class="p">:</span> <span class="mi">1</span>
<span class="n">Batch</span> <span class="n">averaged</span> <span class="n">train</span><span class="o">/</span><span class="n">val</span> <span class="n">loss</span><span class="p">:</span> <span class="mf">2.3298445120453835</span> <span class="mf">1.1800143867731094</span>
<span class="n">Epoch</span> <span class="n">time</span> <span class="mf">0.2888479232788086</span>
<span class="o">-----</span> <span class="n">epoch</span><span class="p">:</span> <span class="mi">2</span>
<span class="n">Batch</span> <span class="n">averaged</span> <span class="n">train</span><span class="o">/</span><span class="n">val</span> <span class="n">loss</span><span class="p">:</span> <span class="mf">0.6962545616552234</span> <span class="mf">0.8775447851845196</span>
<span class="n">Epoch</span> <span class="n">time</span> <span class="mf">0.26888108253479004</span>
<span class="o">-----</span> <span class="n">epoch</span><span class="p">:</span> <span class="mi">3</span>
<span class="n">Batch</span> <span class="n">averaged</span> <span class="n">train</span><span class="o">/</span><span class="n">val</span> <span class="n">loss</span><span class="p">:</span> <span class="mf">0.3671231440966949</span> <span class="mf">0.6234593641545091</span>
<span class="n">Epoch</span> <span class="n">time</span> <span class="mf">0.26917600631713867</span>
</pre></div>
</div>
<p>The first number on each loss line is the weighted training set loss, and the second number is the weighted
validation set loss (which is not used in fitting). While the loss function values
themselves might not be meaningful for error analysis, we output model predictions and targets for
energies and forces in separate files after the fit, as explained below.</p>
</section>
<section id="outputs-and-error-calculation">
<h2><span class="section-number">5.3. </span>Outputs and Error Calculation<a class="headerlink" href="#outputs-and-error-calculation" title="Permalink to this heading"></a></h2>
<p>Unlike linear models, PyTorch models do not output statistics in a dataframe. Instead we output
energy and force comparisons in separate files, along with PyTorch models that can be used to restart
a fit or even run MD simulations in LAMMPS.</p>
<section id="error-comparison-files">
<h3><span class="section-number">5.3.1. </span>Error/Comparison files<a class="headerlink" href="#error-comparison-files" title="Permalink to this heading"></a></h3>
<p>After training a potential, FitSNAP produces outputs that can be used to interpret the quality of a
fit on the training and/or validation data. Basic error metrics for the total set and groups are
output in the metric file declared in the <code class="code docutils literal notranslate"><span class="pre">[OUTFILE]</span></code> section:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="p">[</span><span class="n">OUTFILE</span><span class="p">]</span>
<span class="n">metrics</span> <span class="o">=</span> <span class="n">Ta_metrics</span><span class="o">.</span><span class="n">dat</span> <span class="c1"># filename for Ta example</span>
</pre></div>
</div>
<p>In this example, we write error metrics to a <code class="code docutils literal notranslate"><span class="pre">Ta_metrics.dat</span></code> file.
The first line of this file describes what the columns are:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">Group</span> <span class="n">Train</span><span class="o">/</span><span class="n">Test</span> <span class="n">Property</span> <span class="n">Count</span> <span class="n">MAE</span> <span class="n">RMSE</span>
<span class="o">...</span>
</pre></div>
</div>
<p>where <code class="code docutils literal notranslate"><span class="pre">Count</span></code> is the number of configurations used for energy error, or atoms used for force error.</p>
<p>Fitting progress may be tracked in the <code class="code docutils literal notranslate"><span class="pre">loss_vs_epochs.dat</span></code> file, which tracks training and validation losses.</p>
<p>More detailed fitting metrics are obtained if the following flags are declared true in the
<code class="code docutils literal notranslate"><span class="pre">[EXTRAS]</span></code> section:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="p">[</span><span class="n">EXTRAS</span><span class="p">]</span>
<span class="n">dump_peratom</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># write per-atom fitting metrics</span>
<span class="n">dump_perconfig</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># write per-config fitting metrics</span>
<span class="n">dump_configs</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># write a pickled list of Configuration objects</span>
</pre></div>
</div>
<p>The following comparison files are written after a fit:</p>
<ul class="simple">
<li><p><code class="code docutils literal notranslate"><span class="pre">peratom.dat</span></code> : Fitting information for each atom, such as truth and predicted forces.</p></li>
</ul>
<p>The first line of this file describes what the columns are:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">Filename</span> <span class="n">Group</span> <span class="n">AtomID</span> <span class="n">Type</span> <span class="n">Fx_Truth</span> <span class="n">Fy_Truth</span> <span class="n">Fz_Truth</span> <span class="n">Fx_Pred</span> <span class="n">Fy_Pred</span> <span class="n">Fz_Pred</span> <span class="n">Testing_Bool</span>
</pre></div>
</div>
<ul class="simple">
<li><p><code class="code docutils literal notranslate"><span class="pre">perconfig.dat</span></code> : Fitting information for each configuration, such as truth and predicted energies.</p></li>
</ul>
<p>The first line of this file describes what the columns are:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">Filename</span> <span class="n">Group</span> <span class="n">Natoms</span> <span class="n">Energy_Truth</span> <span class="n">Energy_Pred</span> <span class="n">Testing_Bool</span>
</pre></div>
</div>
<ul class="simple">
<li><p><code class="code docutils literal notranslate"><span class="pre">configs.pickle</span></code> : Structural, descriptor, and fitting info for each configuration.</p></li>
</ul>
<p>This is a pickled list of <a class="reference external" href="https://github.com/FitSNAP/FitSNAP/tree/master/fitsnap3lib/tools/configuration.py">Configuration</a> objects.
Each item in the list contains all associated information of a configuration.</p>
</section>
<section id="pytorch-model-files">
<h3><span class="section-number">5.3.2. </span>PyTorch model files<a class="headerlink" href="#pytorch-model-files" title="Permalink to this heading"></a></h3>
<p>FitSNAP outputs two PyTorch <code class="code docutils literal notranslate"><span class="pre">.pt</span></code> model files after fitting. One is used for restarting a fit
based on an existing model; its name is supplied by the user in the
<code class="code docutils literal notranslate"><span class="pre">save_state_output</span></code> keyword of the input script. In the <a class="reference external" href="https://github.com/FitSNAP/FitSNAP/tree/master/examples/Ta_PyTorch_NN">Ta_PyTorch_NN example</a>
this keyword is set to <code class="code docutils literal notranslate"><span class="pre">Ta_Pytorch.pt</span></code>. This file is therefore saved every epoch, and
it may be fed into FitSNAP via the <code class="code docutils literal notranslate"><span class="pre">save_state_input</span></code> keyword to restart another fit from that
particular model.</p>
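<p>A restart could be sketched as below. Pairing <code class="code docutils literal notranslate"><span class="pre">save_state_input</span></code> with <code class="code docutils literal notranslate"><span class="pre">save_state_output</span></code> in
the same section, and the output name <code class="code docutils literal notranslate"><span class="pre">Ta_Pytorch_restart.pt</span></code>, are assumptions made for illustration;
the remaining keys are copied from the tantalum example.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="p">[</span><span class="n">PYTORCH</span><span class="p">]</span>
<span class="n">layer_sizes</span> <span class="o">=</span> <span class="n">num_desc</span> <span class="mi">60</span> <span class="mi">60</span> <span class="mi">1</span>
<span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">1.5e-4</span>
<span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">save_state_input</span> <span class="o">=</span> <span class="n">Ta_Pytorch</span><span class="o">.</span><span class="n">pt</span> <span class="c1"># model written by the earlier fit</span>
<span class="n">save_state_output</span> <span class="o">=</span> <span class="n">Ta_Pytorch_restart</span><span class="o">.</span><span class="n">pt</span> <span class="c1"># hypothetical name for the continued fit</span>
<span class="n">energy_weight</span> <span class="o">=</span> <span class="mf">1e-2</span>
<span class="n">force_weight</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="n">training_fraction</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="n">multi_element_option</span> <span class="o">=</span> <span class="mi">1</span>
</pre></div>
</div>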
<p>The other PyTorch model is used for running MD simulations in LAMMPS after a fit. This file has the
name <code class="code docutils literal notranslate"><span class="pre">FitTorch_Pytorch.pt</span></code>, and is used to run MD in LAMMPS via the ML-IAP package. An example
is given for tantalum here: <a class="reference external" href="https://github.com/FitSNAP/FitSNAP/tree/master/examples/Ta_PyTorch_NN/MD">https://github.com/FitSNAP/FitSNAP/tree/master/examples/Ta_PyTorch_NN/MD</a></p>
</section>
<section id="calculate-errors-on-a-test-set">
<h3><span class="section-number">5.3.3. </span>Calculate errors on a test set<a class="headerlink" href="#calculate-errors-on-a-test-set" title="Permalink to this heading"></a></h3>
<p>Users may want to use models to calculate errors on a test set that was completely separate from the
training/validation sets used in fitting. To do this, we change the input script to read an existing
PyTorch model file, e.g. for Ta:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="p">[</span><span class="n">PYTORCH</span><span class="p">]</span>
<span class="n">layer_sizes</span> <span class="o">=</span> <span class="n">num_desc</span> <span class="mi">60</span> <span class="mi">60</span> <span class="mi">1</span>
<span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">1.5e-4</span>
<span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1">##### Set to 1 for calculating test errors</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">save_state_input</span> <span class="o">=</span> <span class="n">Ta_Pytorch</span><span class="o">.</span><span class="n">pt</span> <span class="c1">##### Load an existing model</span>
<span class="n">energy_weight</span> <span class="o">=</span> <span class="mf">1e-2</span>
<span class="n">force_weight</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="n">training_fraction</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="n">multi_element_option</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">num_elements</span> <span class="o">=</span> <span class="mi">1</span>
</pre></div>
</div>
<p>Notice how we are now using <code class="code docutils literal notranslate"><span class="pre">save_state_input</span></code> instead of <code class="code docutils literal notranslate"><span class="pre">save_state_output</span></code>, and that
we set <code class="code docutils literal notranslate"><span class="pre">num_epochs</span> <span class="pre">=</span> <span class="pre">1</span></code>. This will load the existing PyTorch model and perform a single epoch,
which involves calculating the energy and force comparisons (mentioned above) for the current model
on whatever groups of configs the user defines in the GROUPS section. We can therefore use the energy and
force comparison files here to calculate mean absolute errors, e.g. with the script in
the <a class="reference external" href="https://github.com/FitSNAP/FitSNAP/tree/master/examples/Ta_PyTorch_NN">Ta_PyTorch_NN example</a>.</p>
</section>
</section>
<section id="training-performance">
<h2><span class="section-number">5.4. </span>Training Performance<a class="headerlink" href="#training-performance" title="Permalink to this heading"></a></h2>
<p>As seen in the <code class="code docutils literal notranslate"><span class="pre">Ta_PyTorch_NN</span></code> example, fitting to ~300 configs (each with ~12 atoms) takes
~0.2 s/epoch. The number of epochs required, and therefore the total time of your fit, will depend
on the size of your dataset <em>and</em> the <code class="code docutils literal notranslate"><span class="pre">batch_size</span></code>. For example, the <code class="code docutils literal notranslate"><span class="pre">Ta_PyTorch_NN</span></code> example
might take ~200 epochs to fully converge (see <code class="code docutils literal notranslate"><span class="pre">loss_vs_epochs.dat</span></code>). In this example, however,
we used <code class="code docutils literal notranslate"><span class="pre">batch_size=4</span></code>, meaning that each epoch involved <code class="code docutils literal notranslate"><span class="pre">~300/4</span> <span class="pre">=</span> <span class="pre">~75</span></code> gradient descent
updates as we cycled through batches. For much larger datasets, the network will experience
more updates with each epoch, and therefore may require fewer epochs to reach
the same convergence.</p>
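<p>The arithmetic behind these figures, using the approximate numbers quoted above, works out roughly as
follows. The totals are estimates derived from those rounded values, not measured timings.</p>
<div class="math notranslate nohighlight">
\[\frac{300\ \text{configs}}{4\ \text{configs per batch}} = 75\ \text{updates per epoch}, \qquad 75 \times 200\ \text{epochs} = 15{,}000\ \text{updates}, \qquad 200\ \text{epochs} \times 0.2\ \text{s} = 40\ \text{s}\]</div>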
<p>For data sets of ~10,000 configs and ~50 atoms per config, training will take ~1 hour at about
20 seconds per epoch (i.e. roughly 180 epochs). This can consume ~20 GB of RAM.</p>
<p>Computational scaling is roughly <code class="code docutils literal notranslate"><span class="pre">O(num_atoms*num_neighs)</span></code> where <code class="code docutils literal notranslate"><span class="pre">num_atoms</span></code> is the
total number of atoms in the training set, and <code class="code docutils literal notranslate"><span class="pre">num_neighs</span></code> is the average number of neighbors
per atom.</p>
<p>Mini-batch network training is embarrassingly parallel up to the batch size, but currently FitSNAP
does not support parallelized NN training.</p>
<section id="gpu-acceleration">
<h3><span class="section-number">5.4.1. </span>GPU Acceleration<a class="headerlink" href="#gpu-acceleration" title="Permalink to this heading"></a></h3>
<p>FitSNAP supports GPU acceleration via PyTorch. Most of the benefit of GPU parallelization comes from
evaluating the NN model and calculating gradients, so with a small batch size you will not see
a large benefit from GPUs unless you have a large NN model (e.g. more than 1 million
parameters). If you have a small model, you will see a speedup on GPUs only with a large enough batch
size.</p>
</section>
</section>
</section>
</div>
</div>
<footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
<a href="Linear.html" class="btn btn-neutral float-left" title="4. Linear Models" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
<a href="Contributing.html" class="btn btn-neutral float-right" title="1. Contributing" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right" aria-hidden="true"></span></a>
</div>
<hr/>
<div role="contentinfo">
<p>© Copyright 2022, Sandia Corporation.</p>
</div>
Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
<a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script>
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
</body>
</html>