diff --git a/src/convnet_magicnet.js b/src/convnet_magicnet.js index 7f3c159c..a60d536b 100644 --- a/src/convnet_magicnet.js +++ b/src/convnet_magicnet.js @@ -50,6 +50,10 @@ this.momentum_max = getopt(opt, 'momentum_max', 0.9); this.neurons_min = getopt(opt, 'neurons_min', 5); this.neurons_max = getopt(opt, 'neurons_max', 30); + + this.input_out_sx = getopt(opt, 'input_out_sx', 1); + this.input_out_sy = getopt(opt, 'input_out_sy', 1); + this.input_out_depth = getopt(opt, 'input_out_depth', this.data[0].w.length); // computed this.folds = []; // data fold indices, gets filled by sampleFolds() @@ -85,12 +89,11 @@ // returns a random candidate network sampleCandidate: function() { - var input_depth = this.data[0].w.length; var num_classes = this.unique_labels.length; // sample network topology and hyperparameters var layer_defs = []; - layer_defs.push({type:'input', out_sx:1, out_sy:1, out_depth: input_depth}); + layer_defs.push({type: 'input', out_sx: this.input_out_sx, out_sy: this.input_out_sy, out_depth: this.input_out_depth}); var nl = weightedSample([0,1,2,3], [0.2, 0.3, 0.3, 0.2]); // prefer nets with 1,2 hidden layers for(var q=0;q