diff --git a/src/main/java/com/dj/core/model/graph/ConnectedNeuron.java b/src/main/java/com/dj/core/model/graph/ConnectedNeuron.java index abe2502..e9d457e 100644 --- a/src/main/java/com/dj/core/model/graph/ConnectedNeuron.java +++ b/src/main/java/com/dj/core/model/graph/ConnectedNeuron.java @@ -170,6 +170,24 @@ public void backwardSignalReceived(final Double error) { final var dzLearningRate = dz * context.getLearningRate(); backwardConnections = backwardConnections.add(inputSignals.scalarMultiply(dzLearningRate)); + if (context.getRegularizationRate() != 0.) { + backwardConnections.walkInColumnOrder(new RealMatrixChangingVisitor() { + @Override + public void start(final int i, final int i1, final int i2, final int i3, final int i4, final int i5) { + + } + + @Override + public double visit(final int i, final int i1, final double v) { + return v - Math.pow(v, context.getRegularizationLevel()) * context.getRegularizationRate(); + } + + @Override + public double end() { + return 0; + } + }); + } bias.addAndGet(inputSignalsAverage * dz * context.getLearningRate()); neuronIndexes diff --git a/src/main/java/com/dj/core/model/graph/Context.java b/src/main/java/com/dj/core/model/graph/Context.java index aef3342..fdb91b4 100644 --- a/src/main/java/com/dj/core/model/graph/Context.java +++ b/src/main/java/com/dj/core/model/graph/Context.java @@ -8,6 +8,10 @@ public class Context implements Serializable { private boolean debugMode; + private int regularizationLevel = 2; + + private double regularizationRate = 0.; + public Context(final double learningRate, final boolean debugMode) { this.learningRate = learningRate; this.debugMode = debugMode; @@ -33,4 +37,20 @@ public boolean isDebugMode() { public void setDebugMode(final boolean debugMode) { this.debugMode = debugMode; } + + public int getRegularizationLevel() { + return regularizationLevel; + } + + public void setRegularizationLevel(final int regularizationLevel) { + this.regularizationLevel = regularizationLevel; + } + + public double getRegularizationRate() { + return regularizationRate; + } + + public void setRegularizationRate(final double regularizationRate) { + this.regularizationRate = regularizationRate; + } }