function [predicted, errors] = nnetTest(data, labels, cost, layers)
%% [predicted, errors] = nnetTest(data, labels, cost, layers) tests the network on new datapoints.
%%
%% Inputs:
%% - data is the matrix of datapoints, one per row;
%% - labels is the matrix of target labels, one per row;
%% - cost is a string naming the error measure to compute; possible
%%   values are 'mse', 'ce', 'nll' and 'class';
%% - layers is the structure array holding the trained network.
%%
%% Outputs:
%% - predicted is the predicted output for each datapoint;
%% - errors is the list of errors for each datapoint.
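%%
%% Example (an illustrative sketch, not part of the original file; testData
%% and testLabels are hypothetical names for a held-out set in the same
%% one-row-per-datapoint layout, and layers is assumed to come from an
%% earlier training run):
%%   [predicted, errors] = nnetTest(testData, testLabels, 'class', layers);
%%   fprintf('Mean test error: %g\n', mean(errors(:)));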
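
% Forward propagation through all layers.
layers = fprop(layers, data);
% The output of the last layer is the network's prediction.
predicted = layers(end).output;
% Compare the predictions with the labels using the requested cost measure.
errors = computeCost(layers, labels, cost);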