diff --git a/math_demo.html b/math_demo.html new file mode 100644 index 0000000..6742f1b --- /dev/null +++ b/math_demo.html @@ -0,0 +1,719 @@ + + +RecurrentJS Math Demo + + + + + + + + + + + + + + + +
+Fork me on GitHub + + +
+

Deep Recurrent Nets math demo

+
+ This demo shows usage of the recurrentjs library that allows you to train deep Recurrent Neural Networks (RNN) and Long Short-Term Memory Networks (LSTM) in Javascript. But the core of the library is more general and allows you to set up arbitrary expression graphs that support fully automatic backpropagation.

+ + In this demo we take a dataset of random math characters as input and learn to memorize the math logic character by character. That is, the RNN/LSTM takes a character, its context from previous time steps (as mediated by the hidden layers) and predicts the next character in the sequence. Here is an example:

+ +
+ + In the example image above that depicts a deep RNN, every character has an associated "letter vector" that we will train with backpropagation. These letter vectors are combined through a (learnable) Matrix-vector multiply transformation into the first hidden layer representation (yellow), then into second hidden layer representation (purple), and finally into the output space (blue). The output space has dimensionality equal to the number of characters in the dataset and every dimension provides the probability of the next character in the sequence. The network is therefore trained to always predict the next character (using Softmax + cross-entropy loss on all letters). The quantity we track during training is called the perplexity, which measures how surprised the network is to see the next character in a sequence. For example, if perplexity is 4.0 then it's as if the network was guessing uniformly at random from 4 possible characters for next letter (i.e. lowest it can be is 1). At test time, the prediction is currently done iteratively character by character in a greedy fashion, but I might eventually implemented more sophisticated methods (e.g. beam search).

+ + The demo is populated with random math from javascript.

+ + For suggestions/bugs ping me at @karpathy.

+ +
+
+
Input sentences:
+ +
+
+ +
Controls/Options:
+ + + + +
+ protip: if your perplexity is exploding with Infinity try lowering the initial learning rate +
+
+ +
+
Training stats:
+
+
Learning rate: you want to anneal this over time if you're training for longer time.
+
+
+
+ + +
+
+
+
+ +
+
+ +
Model samples:
+
+
+
Softmax sample temperature: lower setting will generate more likely predictions, but you'll see more of the same common words again and again. Higher setting will generate less frequent words but you might see more spelling errors.
+
+
+
+
+
+
Greedy argmax prediction:
+
+
+
+
I/O save/load model JSON
+ + + +
+ You can save or load models with JSON using the textarea below. +
+ + +
+
Pretrained model:
+ You can also choose to load an example pretrained model with the button below to see what the predictions look like in later stages. The pretrained model is an LSTM with one layer of 100 units, trained for ~10 hours. After clicking button below you should see the perplexity plummet to about 3.0, and see the predictions become better.
+ + +
+
+ + + + + diff --git a/rnn-viewer.js b/rnn-viewer.js new file mode 100644 index 0000000..558511f --- /dev/null +++ b/rnn-viewer.js @@ -0,0 +1,245 @@ +function RNNViewer(settings) { + Object.assign(this, RNNViewer.defaults, settings); + + this.net = settings.net; + this.boundingGrid = null; + this.values = []; + this.grids = []; + this.matrices = []; + this.controls = null; + this.scene = null; + this.camera = null; + this.light = null; + this.renderer = null; + this.stats = null; + + this.init(); + + if (this.net) { + var model = this.net.model; + var addMatrix = this.addMatrix.bind(this); + + addMatrix(model.input); + + model.hiddenLayers.forEach(function(hiddenLayer) { + for (var p in hiddenLayer) { + if (!hiddenLayer.hasOwnProperty(p)) continue; + addMatrix(hiddenLayer[p]); + } + }); + + addMatrix(model.outputConnector); + addMatrix(model.output); + } + + this.animate(); +} + +RNNViewer.defaults = { + net: null, + container: null, + height: window.innerHeight, + width: window.innerWidth, + depth: 100, + hotColor: new THREE.Color(0xff55f9), + coldColor: new THREE.Color(0x050638), + squareWidth: 10, + squareHeight: 10, + devicePixelRatio: window.devicePixelRatio, + includeStats: false +}; + +RNNViewer.prototype = { + init: function() { + //Set up camera + var vFOVRadians = 2 * Math.atan(this.height / (2 * 1500)), + fov = vFOVRadians * 180 / Math.PI, + startPosition = this.startPosition = new THREE.Vector3(0, 0, 3000); + + var camera = this.camera = new THREE.PerspectiveCamera(fov, this.width / this.height, 1, 30000); + camera.position.set(startPosition.x, startPosition.y, startPosition.z); + + var controls = this.controls = new THREE.OrbitControls(camera); + controls.damping = 0.2; + controls.addEventListener('change', this.render.bind(this)); + + //Create scenes for webGL + var scene = this.scene = new THREE.Scene(); + //Add a light source & create Canvas + var light = this.light = new THREE.DirectionalLight( 0xffffff ); + light.position.set(0, 0, 1); + scene.add(light); + + //set up webGL renderer + var renderer = this.renderer = new THREE.WebGLRenderer(); + renderer.setPixelRatio(this.devicePixelRatio); + renderer.setSize(this.width, this.height); + this.container.appendChild(renderer.domElement); + + //stats + if (this.includeStats) { + var stats = this.stats = new Stats(); + stats.domElement.style.position = 'absolute'; + stats.domElement.style.bottom = '10px'; + stats.domElement.style.left = '10px'; + this.container.appendChild(stats.domElement); + } + + var boundingGrid = this.boundingGrid = new THREE.Object3D(); + scene.add(boundingGrid); + return this; + }, + update: function() { + var hotColor = this.settings.hotColor; + var coldColor = this.settings.coldColor; + return this; + }, + render: function() { + var depth = this.depth; + this.grids.forEach(function(grid, i, grids) { + grid.position.z = (grids.length - i) * depth; + }); + + this.camera.lookAt(this.scene.position); + this.renderer.render(this.scene, this.camera); + if (this.stats) this.stats.update(); + return this; + }, + animate: function() { + this.controls.update(); + window.requestAnimationFrame(this.animate.bind(this)); + return this; + }, + addMatrix: function (matrix) { + var grid = new THREE.Object3D(), + depth = this.depth, + rows = matrix.rows, + columns = matrix.columns, + xPixel = -(this.squareWidth * columns)/ 2, + yPixel = -(this.squareHeight * rows) / 2, + lowValue = 0, + highValue = 0, + index = 0; + + //height + for (var row = 1; row <= rows; row++) { + xPixel = -(this.squareWidth * columns) / 2; + for (var column = 1; column <= columns; column++) { + var color = this.coldColor.clone(); + var material = new THREE.MeshBasicMaterial({ + color: color, + side: THREE.DoubleSide, + vertexColors: THREE.FaceColors + }); + var square = new THREE.Geometry(); + square.vertices.push(new THREE.Vector3(xPixel , yPixel , 0)); + square.vertices.push(new THREE.Vector3(xPixel , yPixel + this.squareHeight , 0)); + square.vertices.push(new THREE.Vector3(xPixel + this.squareWidth, yPixel + this.squareHeight , 0)); + square.vertices.push(new THREE.Vector3(xPixel + this.squareWidth, yPixel , 0)); + + square.faces.push(new THREE.Face3(0, 1, 2)); + square.faces.push(new THREE.Face3(0, 3, 2)); + var mesh = new THREE.Mesh(square, material); + grid.add(mesh); + + this.values.push({ + color: color, + row: row - 1, + column: column - 1, + matrixIndex: this.grids.length, + square: square, + mesh: mesh, + frontFace: mesh.geometry.faces[0], + rearFace: mesh.geometry.faces[1], + index: index, + matrix: matrix, + get value() { + var value = this.matrix.weights[this.index]; + if (value > highValue) { + highValue = value; + } + if (value < lowValue) { + lowValue = value; + } + return value || 0; + }, + get percentValue() { + var value = this.value; + var normalizedHigh = highValue - lowValue; + var normalizedValue = value - lowValue; + return (normalizedHigh - normalizedValue) / normalizedHigh; + } + }); + + xPixel += this.squareWidth; + index++; + } + yPixel += this.squareHeight; + } + + this.grids.push(grid); + this.matrices.push(matrix); + this.boundingGrid.add(grid); + + return this; + }, + viewTop: function() { + this.controls.reset(); + + var vFOVRadians = 2 * Math.atan(this.height / ( 2 * 35000 )), + fov = vFOVRadians * 180 / Math.PI; + + this.camera.fov = fov; + this.controls.rotateUp(90 * Math.PI / 180); + this.camera.position.z = this.startPosition.z * 23; + this.camera.position.y = this.startPosition.z * 55; + this.camera.far = 1000000; + this.camera.updateProjectionMatrix(); + return this.render(); + }, + viewSide: function() { + this.controls.reset(); + + var vFOVRadians = 2 * Math.atan(this.height / ( 2 * 35000 )), + fov = vFOVRadians * 180 / Math.PI; + + this.camera.fov = fov; + this.camera.position.z = this.startPosition.z * 58; + this.camera.far = 1000000; + this.camera.updateProjectionMatrix(); + return this.render(); + }, + viewDefault: function() { + this.controls.reset(); + + this.camera.fov = 30; + this.camera.updateProjectionMatrix(); + return this.render(); + }, + setSize: function(width, height) { + this.width = width; + this.height = height; + this.renderer.setSize(this.width, this.height); + return this.render(); + }, + setValue: function(v) { + var v = Math.random() * 2, + r = (coldColor.r + hotColor.r) / v, + g = (coldColor.g + hotColor.g) / v, + b = (coldColor.b + hotColor.b) / v; + + value.frontFace.color.setRGB( + r, + g, + b + ); + value.rearFace.color.setRGB( + r, + g, + b + ); + value.square.colorsNeedUpdate = true; + //value.mesh.geometry.elementsNeedUpdate = true; + value.mesh.geometry.colorsNeedUpdate = true; + } +}; \ No newline at end of file diff --git a/src/recurrent.js b/src/recurrent.js index 4c20a39..221d4af 100644 --- a/src/recurrent.js +++ b/src/recurrent.js @@ -1,6 +1,6 @@ var R = {}; // the Recurrent library -(function(global) { +(function (global) { "use strict"; // Utility fun @@ -18,341 +18,417 @@ var R = {}; // the Recurrent library // Random numbers utils var return_v = false; var v_val = 0.0; - var gaussRandom = function() { - if(return_v) { + var gaussRandom = function () { + if (return_v) { return_v = false; - return v_val; + return v_val; } - var u = 2*Math.random()-1; - var v = 2*Math.random()-1; - var r = u*u + v*v; - if(r == 0 || r > 1) return gaussRandom(); - var c = Math.sqrt(-2*Math.log(r)/r); - v_val = v*c; // cache this + var u = 2 * Math.random() - 1; + var v = 2 * Math.random() - 1; + var r = u * u + v * v; + if (r == 0 || r > 1) return gaussRandom(); + var c = Math.sqrt(-2 * Math.log(r) / r); + v_val = v * c; // cache this return_v = true; - return u*c; - } - var randf = function(a, b) { return Math.random()*(b-a)+a; } - var randi = function(a, b) { return Math.floor(Math.random()*(b-a)+a); } - var randn = function(mu, std){ return mu+gaussRandom()*std; } + return u * c; + }; + var randf = function (a, b) { + return Math.random() * (b - a) + a; + }; + var randi = function (a, b) { + return Math.floor(Math.random() * (b - a) + a); + }; + var randn = function (mu, std) { + return mu + gaussRandom() * std; + }; // helper function returns array of zeros of length n // and uses typed arrays if available - var zeros = function(n) { - if(typeof(n)==='undefined' || isNaN(n)) { return []; } - if(typeof ArrayBuffer === 'undefined') { + var zeros = function (n) { + if (typeof(n) === 'undefined' || isNaN(n)) { + return []; + } + if (typeof ArrayBuffer === 'undefined') { // lacking browser support var arr = new Array(n); - for(var i=0;i= 0 && ix < this.w.length); return this.w[ix]; }, - set: function(row, col, v) { + set: function (row, col, v) { // slow but careful accessor function var ix = (this.d * row) + col; assert(ix >= 0 && ix < this.w.length); - this.w[ix] = v; + this.w[ix] = v; }, - toJSON: function() { + toJSON: function () { var json = {}; json['n'] = this.n; json['d'] = this.d; json['w'] = this.w; return json; }, - fromJSON: function(json) { + fromJSON: function (json) { this.n = json.n; this.d = json.d; this.w = zeros(this.n * this.d); this.dw = zeros(this.n * this.d); - for(var i=0,n=this.n * this.d;i=0;i--) { + backward: function () { + for (var i = this.backprop.length - 1; i >= 0; i--) { this.backprop[i](); // tick! } }, - rowPluck: function(m, ix) { + rowPluck: function (m, ix) { // pluck a row of m with index ix and return it as col vector assert(ix >= 0 && ix < m.n); var d = m.d; var out = new Mat(d, 1); - for(var i=0,n=d;i 0 ? out.dw[i] : 0.0; } - } + }; this.backprop.push(backward); } return out; }, - mul: function(m1, m2) { + mul: function (m1, m2) { // multiply matrices m1 * m2 assert(m1.d === m2.n, 'matmul dimensions misaligned'); var n = m1.n; var d = m2.d; - var out = new Mat(n,d); - for(var i=0;i maxval) maxval = m.w[i]; } + function dwListen(m) { + if (m._dw) return; + m._dw = m.dw; + m.dw = { + length: m._dw.length + }; + m._dw.forEach(function (value, i) { + (function (i) { + m.dw.__defineSetter__(i.toString(), function (value) { + m._dw[i] = value; + }); + m.dw.__defineGetter__(i.toString(), function () { + return m._dw[i]; + }); + })(i) + }); + } - var s = 0.0; - for(var i=0,n=m.w.length;i maxval) maxval = m.w[i]; + } - // no backward pass here needed - // since we will use the computed probabilities outside - // to set gradients directly on m - return out; + var s = 0.0; + for (var i = 0, n = m.w.length; i < n; i++) { + out.w[i] = Math.exp(m.w[i] - maxval); + s += out.w[i]; } + for (var i = 0, n = m.w.length; i < n; i++) { + out.w[i] /= s; + } + + // no backward pass here needed + // since we will use the computed probabilities outside + // to set gradients directly on m + return out; + }; - var Solver = function() { + var Solver = function () { this.decay_rate = 0.999; this.smooth_eps = 1e-8; this.step_cache = {}; - } + }; Solver.prototype = { - step: function(model, step_size, regc, clipval) { + step: function (model, step_size, regc, clipval) { // perform parameter update var solver_stats = {}; var num_clipped = 0; var num_tot = 0; - for(var k in model) { - if(model.hasOwnProperty(k)) { + for (var k in model) { + if (model.hasOwnProperty(k)) { var m = model[k]; // mat ref - if(!(k in this.step_cache)) { this.step_cache[k] = new Mat(m.n, m.d); } + if (!(k in this.step_cache)) { + this.step_cache[k] = new Mat(m.n, m.d); + } var s = this.step_cache[k]; - for(var i=0,n=m.w.length;i clipval) { + if (mdwi > clipval) { mdwi = clipval; num_clipped++; } - if(mdwi < -clipval) { + if (mdwi < -clipval) { mdwi = -clipval; num_clipped++; } num_tot++; // update (and regularize) - m.w[i] += - step_size * mdwi / Math.sqrt(s.w[i] + this.smooth_eps) - regc * m.w[i]; + m.w[i] += -step_size * mdwi / Math.sqrt(s.w[i] + this.smooth_eps) - regc * m.w[i]; m.dw[i] = 0; // reset gradients for next iteration } } } - solver_stats['ratio_clipped'] = num_clipped*1.0/num_tot; + solver_stats['ratio_clipped'] = num_clipped * 1.0 / num_tot; return solver_stats; } - } + }; - var initLSTM = function(input_size, hidden_sizes, output_size) { + var initLSTM = function (input_size, hidden_sizes, output_size) { // hidden size should be a list var model = {}; - for(var d=0;d