Skip to content

Commit 8ff74b1

Browse files
committed
avoid casting to Connection
1 parent 4c4a9a3 commit 8ff74b1

File tree

3 files changed

+22
-22
lines changed

3 files changed

+22
-22
lines changed

nature-of-code/xor/src/nn/Connection.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ public Connection(Neuron a_, Neuron b_, double w) {
2626
weight = w;
2727
}
2828

29-
public Neuron getFrom() {
29+
public Neuron from() {
3030
return from;
3131
}
3232

33-
public Neuron getTo() {
33+
public Neuron to() {
3434
return to;
3535
}
3636

37-
public double getWeight() {
37+
public double weight() {
3838
return weight;
3939
}
4040

nature-of-code/xor/src/nn/Network.java

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public double feedForward(double[] inputVals) {
8686
output.calcOutput();
8787

8888
// Return output
89-
return output.getOutput();
89+
return output.output();
9090
}
9191

9292
public double train(double[] inputs, double answer) {
@@ -101,11 +101,11 @@ public double train(double[] inputs, double answer) {
101101
// BACKPROPOGATION
102102
// This is easier b/c we just have one output
103103
// Apply Delta to connections between hidden and output
104-
ArrayList connections = output.getConnections();
104+
ArrayList<Connection> connections = output.getConnections();
105105
for (int i = 0; i < connections.size(); i++) {
106-
Connection c = (Connection) connections.get(i);
107-
Neuron neuron = c.getFrom();
108-
double loutput = neuron.getOutput();
106+
Connection c = connections.get(i);
107+
Neuron neuron = c.from();
108+
double loutput = neuron.output();
109109
double deltaWeight = loutput*deltaOutput;
110110
c.adjustWeight(LEARNING_CONSTANT*deltaWeight);
111111
}
@@ -116,24 +116,24 @@ public double train(double[] inputs, double answer) {
116116
double sum = 0;
117117
// Sum output delta * hidden layer connections (just one output)
118118
for (int j = 0; j < connections.size(); j++) {
119-
Connection c = (Connection) connections.get(j);
119+
Connection c = connections.get(j);
120120
// Is this a connection from hidden layer to next layer (output)?
121-
if (c.getFrom() == hidden1) {
122-
sum += c.getWeight()*deltaOutput;
121+
if (c.from() == hidden1) {
122+
sum += c.weight()*deltaOutput;
123123
}
124124
}
125125
// Then adjust the weights coming in based:
126126
// Above sum * derivative of sigmoid output function for hidden neurons
127127
for (int j = 0; j < connections.size(); j++) {
128-
Connection c = (Connection) connections.get(j);
128+
Connection c = connections.get(j);
129129
// Is this a connection from previous layer (input) to hidden layer?
130-
if (c.getTo() == hidden1) {
131-
double loutput = hidden1.getOutput();
130+
if (c.to() == hidden1) {
131+
double loutput = hidden1.output();
132132
double deltaHidden = loutput * (1 - loutput); // Derivative of sigmoid(x)
133133
deltaHidden *= sum; // Would sum for all outputs if more than one output
134-
Neuron neuron = c.getFrom();
135-
double deltaWeight = neuron.getOutput()*deltaHidden;
136-
c.adjustWeight(LEARNING_CONSTANT*deltaWeight);
134+
Neuron neuron = c.from();
135+
double deltaWeight = neuron.output() * deltaHidden;
136+
c.adjustWeight(LEARNING_CONSTANT * deltaWeight);
137137
}
138138
}
139139
}

nature-of-code/xor/src/nn/Neuron.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,17 @@ public void calcOutput() {
4141
//System.out.println("Looking through " + connections.size() + " connections");
4242
for (int i = 0; i < connections.size(); i++) {
4343
Connection c = connections.get(i);
44-
Neuron from = c.getFrom();
45-
Neuron to = c.getTo();
44+
Neuron from = c.from();
45+
Neuron to = c.to();
4646
// Is this connection moving forward to us
4747
// Ignore connections that we send our output to
4848
if (to == this) {
4949
// This isn't really necessary
5050
// But I am treating the bias individually in case I need to at some point
5151
if (from.bias) {
52-
lbias = from.getOutput()*c.getWeight();
52+
lbias = from.output()*c.weight();
5353
} else {
54-
sum += from.getOutput()*c.getWeight();
54+
sum += from.output()*c.weight();
5555
}
5656
}
5757
}
@@ -64,7 +64,7 @@ void addConnection(Connection c) {
6464
connections.add(c);
6565
}
6666

67-
double getOutput() {
67+
double output() {
6868
return output;
6969
}
7070

0 commit comments

Comments
 (0)