-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnode.py
More file actions
65 lines (50 loc) · 1.99 KB
/
node.py
File metadata and controls
65 lines (50 loc) · 1.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import math
import numpy as np
import random
import activation
from trainable import Weight, Bias
class Node():
    """Marker base class for network nodes; defines no behavior of its own."""


class Neuron(Node):
    """A single neuron: activation(weights . input + bias).

    Attributes:
        weights: trainable ``Weight`` created from ``numConnections``.
        bias: trainable ``Bias`` added to the weighted sum.
        numConnections: connection count passed to ``__init__`` / ``reset``.
        input: last input given to :meth:`evaluate` (``None`` until then).
        activation: last output of :meth:`evaluate` (``0`` until then).
        activationFunction: activation object; ``activation.Relu()`` by default.
    """

    def __init__(self, numConnections=None, activationFunc=None):
        """Create a neuron with fresh weights and bias.

        Args:
            numConnections: number of input connections, forwarded to ``Weight``.
            activationFunc: activation function object; when ``None`` a new
                ``activation.Relu()`` is used.
        """
        self.weights = Weight(numConnections)
        self.bias = Bias()
        # Kept so __str__ can report it (it was previously read but never set).
        self.numConnections = numConnections
        self.input = None
        self.activation = 0
        self.activationFunction = self._resolve_activation(activationFunc)

    def reset(self, numConnections, activationFunc=None):
        """Re-initialise weights, bias, cached state and activation in place.

        Args:
            numConnections: new number of input connections, forwarded to
                ``self.weights.reset``.
            activationFunc: activation function object; when ``None`` a new
                ``activation.Relu()`` is used.
        """
        self.weights.reset(numConnections)
        self.bias.reset()
        self.numConnections = numConnections
        self.input = None
        self.activation = 0
        self.activationFunction = self._resolve_activation(activationFunc)

    @staticmethod
    def _resolve_activation(activationFunc):
        """Return *activationFunc*, or a new ``activation.Relu()`` if it is ``None``."""
        if activationFunc is not None:
            return activationFunc
        return activation.Relu()

    def evaluate(self, input: np.ndarray):
        """Compute, cache and return this neuron's output for *input*.

        Args:
            input: array whose ``shape`` must equal ``self.weights.shape()``.

        Returns:
            ``self.activationFunction.evaluate(weights.dot(input) + bias)``.

        Raises:
            AssertionError: if the shapes differ (only when asserts are enabled).
        """
        assert input.shape == self.weights.shape(), (
            f"input shape {input.shape} does not match "
            f"weights shape {self.weights.shape()}"
        )
        self.input = input
        weightedSum = self.weights.dot(input) + self.bias
        # Named `output` so it does not shadow the imported `activation` module.
        output = self.activationFunction.evaluate(weightedSum)
        self.activation = output
        return output

    # Annotation is quoted so it is not evaluated at class-definition time.
    def setActivation(self, activationFunc: "activation.ActivationFunction"):
        """Replace the activation function used by :meth:`evaluate`."""
        self.activationFunction = activationFunc

    def activationFunctionDerivative(self):
        '''
        Return ``self.activationFunction.evaluateDerivative()``.

        The call takes no arguments, so it presumably relies on state cached by
        the activation function's last ``evaluate`` call; it may break if there
        has been no function calculation before derivative calculation.
        '''
        return self.activationFunction.evaluateDerivative()

    def equals(self, other):
        """Return True if *other* has equal weights, bias and activation function.

        Weights are compared element-wise by index. The bias and activation
        function are compared with ``==``; for activation objects without their
        own ``__eq__`` that means identity.
        """
        count = len(self.weights)
        if count != len(other.weights):
            return False
        weights_match = all(
            self.weights[i] == other.weights[i] for i in range(count)
        )
        return (
            weights_match
            and self.bias == other.bias
            and self.activationFunction == other.activationFunction
        )

    def __str__(self):
        """Human-readable summary of weights, bias and connection count."""
        return f"weights: {self.weights} \nbias: {self.bias} \nnum connections: {self.numConnections}"