Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 4f0de84

Browse files
cghawthornedsmilkov
authored andcommitted
Add LSTM ops and a simple LSTM demo (#47)
* BasicLSTMCell * cleanup * 2-layer lstm * add training script for pi digits * start * interface * docs * lint * lint * fixes * updates based on comments
1 parent 49b71ad commit 4f0de84

13 files changed

+452
-0
lines changed

demos/lstm/README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Learning digits of pi using an LSTM
2+
3+
Demonstrates training a simple autoregressive LSTM network in Tensorflow and
4+
then porting that model to deeplearn.js.
5+
6+
This network uses two ``BasicLSTMCell``s combined with ``MultiRNNCell``. The
7+
network is trained to memorize the first few digits of pi.
8+
9+
First, train the LSTM network with Tensorflow:
10+
11+
```
12+
python demos/lstm/train.py
13+
```
14+
15+
Next, export the weights to be used by deeplearn.js:
16+
17+
```
18+
python scripts/dump_checkpoint_vars.py --output_dir=demos/lstm/ --checkpoint_file=/tmp/simple_lstm-1000 --remove_variables_regex=".*Adam.*|.*beta.*"
19+
```
20+
21+
Finally, start the demo:
22+
23+
```
24+
scripts/watch-demo demos/lstm/lstm.ts
25+
```

demos/lstm/fully_connected_biases

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
,:M�P�>Ƴ�mZK�fN<>��ֻ�G��d���6��H�

demos/lstm/fully_connected_weights

800 Bytes
Binary file not shown.

demos/lstm/index.html

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
<!-- Copyright 2017 Google Inc. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License.
11+
==============================================================================-->
12+
<html>
13+
<head>
14+
<title>LSTM Demo</title>
15+
</head>
16+
<body>
17+
<h1>LSTM demo</h1>
18+
<div>Expected:</div>
19+
<div id="expected"></div>
20+
<div>Results:</div>
21+
<div id="results"></div>
22+
<div id="success"></div>
23+
<script src="bundle.js"></script>
24+
</body>
25+
</html>

demos/lstm/lstm_inference.ts

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/* Copyright 2017 Google Inc. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
import {Array1D, Array2D, CheckpointLoader, NDArrayMathGPU, Scalar,
17+
util} from '../deeplearnjs';
18+
19+
// manifest.json lives in the same directory.
20+
const reader = new CheckpointLoader('.');
21+
reader.getAllVariables().then(vars => {
22+
const primerData = 3;
23+
const expected = [1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4];
24+
const math = new NDArrayMathGPU();
25+
26+
const lstmKernel1 = vars[
27+
'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel'] as Array2D;
28+
const lstmBias1 = vars[
29+
'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias'] as Array1D;
30+
31+
const lstmKernel2 = vars[
32+
'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel'] as Array2D;
33+
const lstmBias2 = vars[
34+
'rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias'] as Array1D;
35+
36+
const fullyConnectedBiases = vars['fully_connected/biases'] as Array1D;
37+
const fullyConnectedWeights = vars['fully_connected/weights'] as Array2D;
38+
39+
const results: number[] = [];
40+
41+
math.scope((keep, track) => {
42+
const forgetBias = track(Scalar.new(1.0));
43+
const lstm1 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel1,
44+
lstmBias1);
45+
const lstm2 = math.basicLSTMCell.bind(math, forgetBias, lstmKernel2,
46+
lstmBias2);
47+
48+
let c = [track(Array2D.zeros([1, lstmBias1.shape[0] / 4])),
49+
track(Array2D.zeros([1, lstmBias2.shape[0] / 4]))];
50+
let h = [track(Array2D.zeros([1, lstmBias1.shape[0] / 4])),
51+
track(Array2D.zeros([1, lstmBias2.shape[0] / 4]))];
52+
53+
let input = primerData;
54+
for (let i = 0; i < expected.length; i++) {
55+
const onehot = track(Array2D.zeros([1, 10]));
56+
onehot.set(1.0, 0, input);
57+
58+
const output = math.multiRNNCell([lstm1, lstm2], onehot, c, h);
59+
60+
c = output[0];
61+
h = output[1];
62+
63+
const outputH = h[1];
64+
const weightedResult = math.matMul(outputH, fullyConnectedWeights);
65+
const logits = math.add( weightedResult, fullyConnectedBiases);
66+
67+
const result = math.argMax(logits).get();
68+
results.push(result);
69+
input = result;
70+
}
71+
});
72+
document.getElementById('expected').innerHTML = '' + expected;
73+
document.getElementById('results').innerHTML = '' + results;
74+
if(util.arraysEqual(expected, results)) {
75+
document.getElementById('success').innerHTML = 'Success!';
76+
} else {
77+
document.getElementById('success').innerHTML = 'Failure.';
78+
}
79+
});

demos/lstm/manifest.json

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
{
2+
"fully_connected/biases": {
3+
"filename": "fully_connected_biases",
4+
"shape": [
5+
10
6+
]
7+
},
8+
"fully_connected/weights": {
9+
"filename": "fully_connected_weights",
10+
"shape": [
11+
20,
12+
10
13+
]
14+
},
15+
"rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias": {
16+
"filename": "rnn_multi_rnn_cell_cell_0_basic_lstm_cell_bias",
17+
"shape": [
18+
80
19+
]
20+
},
21+
"rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel": {
22+
"filename": "rnn_multi_rnn_cell_cell_0_basic_lstm_cell_kernel",
23+
"shape": [
24+
30,
25+
80
26+
]
27+
},
28+
"rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias": {
29+
"filename": "rnn_multi_rnn_cell_cell_1_basic_lstm_cell_bias",
30+
"shape": [
31+
80
32+
]
33+
},
34+
"rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel": {
35+
"filename": "rnn_multi_rnn_cell_cell_1_basic_lstm_cell_kernel",
36+
"shape": [
37+
40,
38+
80
39+
]
40+
}
41+
}
320 Bytes
Binary file not shown.
9.38 KB
Binary file not shown.
320 Bytes
Binary file not shown.
12.5 KB
Binary file not shown.

0 commit comments

Comments
 (0)