diff --git a/lib/neuralnetwork.js b/lib/neuralnetwork.js index 02fae0d..d7ff6ba 100644 --- a/lib/neuralnetwork.js +++ b/lib/neuralnetwork.js @@ -88,6 +88,7 @@ NeuralNetwork.prototype = { var learningRate = options.learningRate || this.learningRate || 0.3; var callback = options.callback; var callbackPeriod = options.callbackPeriod || 10; + var initialization = _.isUndefined(options.initialization) ? true : options.initialization; var inputSize = data[0].input.length; var outputSize = data[0].output.length; @@ -97,7 +98,9 @@ NeuralNetwork.prototype = { hiddenSizes = [Math.max(3, Math.floor(inputSize / 2))]; } var sizes = _([inputSize, hiddenSizes, outputSize]).flatten(); - this.initialize(sizes); + if (initialization) { + this.initialize(sizes); + } var error = 1; for (var i = 0; i < iterations && error > errorThresh; i++) { diff --git a/test/unit/trainopts.js b/test/unit/trainopts.js index 5ca84e7..9d09376 100644 --- a/test/unit/trainopts.js +++ b/test/unit/trainopts.js @@ -48,4 +48,18 @@ describe('train() options', function() { callbackPeriod: 20 }); }); + + it('trains the neural network without (re)initializing the weights', function() { + var net = new brain.NeuralNetwork(); + var firstTrainingError = net.train(data, { + errorThresh: 0.2, + iterations: 100000 + }).error; + var secondTrainingError = net.train(data, { + errorThresh: 0.2, + iterations: 1, + initialization: false + }).error; + assert.ok(secondTrainingError <= firstTrainingError); + }); })