-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdataset.h
More file actions
33 lines (28 loc) · 1019 Bytes
/
dataset.h
File metadata and controls
33 lines (28 loc) · 1019 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#pragma once
#include <vector>
#include <string>
namespace nn
{
// Plain aggregate holding a dataset that is already divided into a train part
// and a test part. Samples and labels are kept in flat float vectors.
struct Dataset
{
    // Train part: sample values and their labels.
    std::vector<float> train_data;
    std::vector<float> train_labels;

    // Test part: sample values and their labels.
    std::vector<float> test_data;
    std::vector<float> test_labels;

    // NOTE(review): presumably the per-dimension extent of a single sample
    // (e.g. channels/height/width) - confirm against the readers.
    std::vector<unsigned int> element_size;

    // NOTE(review): presumably the number of floats that make up one label - confirm.
    unsigned int label_size{1};

    // NOTE(review): presumably the number of samples stored in each part - confirm.
    unsigned int train_elements{0};
    unsigned int test_elements{0};
};
//Save / load an already prepared dataset to / from a binary file.
//  path        - path of the binary file (written by save_dataset, read by load_dataset)
//  dataset     - dataset to save; passed as pointer-to-const
//  out_dataset - dataset object that load_dataset writes into
//NOTE(review): the file layout and the behaviour on I/O failure or null pointers
//are defined by the implementation, which is not visible in this header - confirm there.
void save_dataset(std::string path, const Dataset *dataset);
void load_dataset(std::string path, Dataset *out_dataset);
//NOTE(review): presumably redistributes the elements of *dataset between its train and
//test parts so that roughly test_fraction of them end up in the test part - the
//implementation is not visible here, confirm (including whether the split is randomized).
//  dataset       - dataset to split; non-const pointer
//  test_fraction - defaults to 0.1 (a double literal converted to float)
void train_test_split(Dataset *dataset, float test_fraction = 0.1);
//Labels are treated as one-hot encoded class marks.
//This function duplicates entries of rare classes
//and checks that all classes exist in the dataset.
//NOTE(review): which part(s) of the dataset (train / test) are rebalanced and how a
//missing class is reported are not visible from this header - confirm in the implementation.
void rebalance_classes(Dataset *out_dataset);
//Read datasets from their specific binary formats (CIFAR-10 and MNIST).
//  path        - location of the dataset on disk
//                NOTE(review): unclear from this header whether this is a single file
//                or a directory containing the dataset files - confirm.
//  out_dataset - dataset object the reader writes into
void read_CIFAR10_dataset(std::string path, Dataset *out_dataset);
void read_MNIST_dataset(std::string path, Dataset *out_dataset);
}