-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNeuralNetwork.hpp
More file actions
53 lines (41 loc) · 1.94 KB
/
NeuralNetwork.hpp
File metadata and controls
53 lines (41 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
// NeuralNetwork.hpp
#pragma once

#include <iostream>
#include <string>
#include <vector>

#include <Eigen/Eigen>
// Numeric type aliases used throughout the network code.
// Every Eigen alias below is derived from Scalar, so switching precision
// (e.g. float -> double) only requires editing the Scalar line; the matrix and
// vector types follow automatically instead of staying hard-wired to float.
typedef float Scalar;
// Dynamically sized matrix of Scalar (identical to Eigen::MatrixXf while Scalar is float).
typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix;
// Dynamically sized row vector of Scalar (identical to Eigen::RowVectorXf while Scalar is float).
typedef Eigen::Matrix<Scalar, 1, Eigen::Dynamic> RowVector;
// Dynamically sized column vector of Scalar (identical to Eigen::VectorXf while Scalar is float).
typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> ColVector;
// neural network implementation class!
class NeuralNetwork {
public:
std::vector<uint> topology;
// constructor
NeuralNetwork(std::vector<uint> topology, Scalar learningRate = Scalar(0.005));
// function for forward propagation of data
void propagateForward(RowVector& input);
// function for backward propagation of errors made by neurons
void propagateBackward(RowVector& output);
// function to calculate errors made by neurons in each layer
void calcErrors(RowVector& output);
// function to update the weights of connections
void updateWeights();
// function to train the neural network give an array of data points
void train(std::vector<RowVector*> input, std::vector<RowVector*> output);
// storage objects for working of neural network
/*
use pointers when using std::vector<Class> as std::vector<Class> calls destructor of
Class as soon as it is pushed back! when we use pointers it can't do that, besides
it also makes our neural network class less heavy!! It would be nice if you can use
smart pointers instead of usual ones like this
*/
std::vector<RowVector*> neuronLayers; // stores the different layers of out network
std::vector<RowVector*> cacheLayers; // stores the unactivated (activation fn not yet applied) values of layers
std::vector<RowVector*> deltas; // stores the error contribution of each neurons
std::vector<Matrix*> weights; // the connection weights itself
Scalar learningRate;
};
// Activation function: maps one Scalar to another. The definition is not in
// this header, so which function it is (sigmoid, tanh, ...) must be checked there.
Scalar activationFunction(Scalar x);
// Takes a file name and, by reference, a vector of RowVector pointers.
// NOTE(review): presumably reads the CSV file and appends one RowVector per row
// to `data` -- confirm against the definition, including who owns the pointers.
void ReadCSV(std::string filename, std::vector<RowVector*>& data);
// Takes a file name.
// NOTE(review): presumably generates a data set and writes it to `filename` --
// confirm against the definition.
void genData(std::string filename);