-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.cpp
More file actions
executable file
·199 lines (189 loc) · 7.72 KB
/
main.cpp
File metadata and controls
executable file
·199 lines (189 loc) · 7.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
//
// Created by Yujia Shen on 4/12/18.
//
extern "C" {
#include <sdd/sddapi.h>
}
#include <structured_bn/network.h>
#include <structured_bn/network_compiler.h>
#include <util/optionparser.h>
#include <cassert>
#include <chrono>
#include <iostream>
#include <memory>
using ms = std::chrono::milliseconds;
using get_time = std::chrono::steady_clock;
// Argument validators plugged into the option table below. Each checker
// returns option::ARG_OK when the option's argument is acceptable and
// option::ARG_ILLEGAL otherwise; when `msg` is true a diagnostic is written
// to stderr before rejecting.
struct Arg : public option::Arg {
  // Writes msg1, then the option name exactly as it was typed (opt.name is
  // length-delimited by opt.namelen), then msg2 to stderr.
  static void printError(const char *msg1, const option::Option &opt,
                         const char *msg2) {
    fputs(msg1, stderr);
    fwrite(opt.name, static_cast<size_t>(opt.namelen), 1, stderr);
    fputs(msg2, stderr);
  }

  // Accepts the option only when an argument was attached to it.
  static option::ArgStatus Required(const option::Option &option, bool msg) {
    if (option.arg == nullptr) {
      if (msg) printError("Option '", option, "' requires an argument\n");
      return option::ARG_ILLEGAL;
    }
    return option::ARG_OK;
  }

  // Accepts the option only when its argument is non-empty and strtol
  // (base 10) consumes it up to the terminating NUL.
  static option::ArgStatus Numeric(const option::Option &option, bool msg) {
    const char *const text = option.arg;
    if (text != nullptr) {
      char *parse_end = nullptr;
      // Only the stop position matters here; the parsed value is discarded.
      (void)strtol(text, &parse_end, 10);
      const bool consumed_something = (parse_end != text);
      if (consumed_something && *parse_end == '\0') return option::ARG_OK;
    }
    if (msg) printError("Option '", option, "' requires a numeric argument\n");
    return option::ARG_ILLEGAL;
  }
};
// Identifiers of the command-line options. The values are used as the first
// field of each row in `usage` and as indices into the parsed `options`
// vector in main (e.g. options[HELP]), so they must stay dense and start at 0.
enum optionIndex {
  UNKNOWN = 0,                       // anything not matched by another row
  HELP = 1,                          // -h / --help
  SPARSE_LEARNING_DATASET_FILE = 2,  // --sparse_learning_dataset
  LEARNING_DATASET_FILE = 3,         // --learning_dataset
  PSDD_FILENAME = 4,                 // --psdd_filename
  VTREE_FILENAME = 5,                // --vtree_filename
  CONSISTENT_CHECK = 6,              // --consistent_check
  SAMPLE_PARAMETER = 7,              // --sample_parameter
  SEED = 8                           // -s / --seed
};
// Option table handed to option::Stats, option::Parser and
// option::printUsage in main. Each row pairs an optionIndex with its short
// name(s), long name, argument checker and help text; the all-zero row
// terminates the table.
//
// Help-text convention used here: "<option> \t<description>", i.e. a tab
// separates the option column from the description column so that all option
// rows line up when the usage is printed. The UNKNOWN rows carry free-form
// header/footer text instead.
const option::Descriptor usage[] = {
    {UNKNOWN, 0, "", "", option::Arg::None,
     "USAGE: example [options]\n\n \tOptions:"},
    {HELP, 0, "h", "help", option::Arg::None,
     "--help \tPrint usage and exit."},
    {SPARSE_LEARNING_DATASET_FILE, 0, "", "sparse_learning_dataset",
     Arg::Required,
     "--sparse_learning_dataset \tSet sparse dataset file which is used to "
     "learn parameters in the SBN"},
    {LEARNING_DATASET_FILE, 0, "", "learning_dataset", Arg::Required,
     "--learning_dataset \tSet dataset file which is used to learn parameters "
     "in the SBN"},
    {PSDD_FILENAME, 0, "", "psdd_filename", Arg::Required,
     "--psdd_filename \tthe output filename for the compiled psdd."},
    {VTREE_FILENAME, 0, "", "vtree_filename", Arg::Required,
     "--vtree_filename \tthe output filename for joint vtree"},
    {CONSISTENT_CHECK, 0, "", "consistent_check", option::Arg::None,
     "--consistent_check \tCheck whether learning data is consistent"},
    {SAMPLE_PARAMETER, 0, "", "sample_parameter", option::Arg::None,
     "--sample_parameter \t Sample parameter from Gamma distribution"},
    {SEED, 0, "s", "seed", Arg::Required,
     "--seed \t Seed to be used. default is 0"},
    {UNKNOWN, 0, "", "", option::Arg::None,
     "\nExamples:\n./structured_bn_main --psdd_filename <psdd_filename> "
     "--vtree_filename <vtree_filename> network.json\n"},
    {0, 0, 0, 0, 0, 0}};
using structured_bn::Network;
using structured_bn::NetworkCompiler;
int main(int argc, const char *argv[]) {
argc -= (argc > 0);
argv += (argc > 0); // skip program name argv[0] if present
option::Stats stats(usage, argc, argv);
std::vector<option::Option> options(stats.options_max);
std::vector<option::Option> buffer(stats.buffer_max);
option::Parser parse(usage, argc, argv, &options[0], &buffer[0]);
if (parse.error()) return 1;
if (options[HELP] || argc == 0) {
option::printUsage(std::cout, usage);
return 0;
}
const char *network_file = parse.nonOption(0);
uint seed = 0;
if (options[SEED]) {
seed = (uint)std::strtol(options[SEED].arg, nullptr, 10);
}
// Create Network from user-specified SBN
std::cout << "Loading Network File " << network_file << std::endl;
auto start = get_time::now();
Network *network = Network::GetNetworkFromSpecFile(network_file);
auto end = get_time::now();
std::cout << "Network Loading Time : "
<< std::chrono::duration_cast<ms>(end - start).count() << " ms"
<< std::endl;
// Generate random parameters for each parameter set.
if (options[SAMPLE_PARAMETER]) {
std::cout << "Sample parameters in the network" << std::endl;
RandomDoubleFromGammaGenerator generator(1.0, 1.0, seed);
network->SampleParameters(&generator);
} else {
// Create training dataset
BinaryData *train_data = nullptr;
// Use user-specified data file
if (options[LEARNING_DATASET_FILE]) {
const char *data_file = options[LEARNING_DATASET_FILE].arg;
std::cout << "Learning parameters from data file " << data_file
<< std::endl;
train_data = new BinaryData();
train_data->ReadFile(data_file);
}
// Use user-specified sparse data file
else if (options[SPARSE_LEARNING_DATASET_FILE]) {
const char *data_file = options[SPARSE_LEARNING_DATASET_FILE].arg;
std::cout << "Learning parameter from sparse data file " << data_file
<< std::endl;
train_data = BinaryData::ReadSparseDataJsonFile(data_file);
}
// Use empty dataset
else {
std::cout << "Learning parameter with 0 data" << std::endl;
train_data = new BinaryData();
}
// Learn the prior bias using Laplacian Smoothing
// Laplacian Smoothing, aka "add one smoothing", adds a pseudocount of one to
// each parameter in each parameter set, before normalization.
// PsddParameter::CreateFromDecimal(1) - treats 1 as 1, 2 as 2, and so on (decimal space)
// PsddParameter::CreateFromLog(1) - treats 1 as e, 2 as e^2, and so on (log space)
// 1 is the pseudocount
start = get_time::now();
network->LearnParametersUsingLaplacianSmoothing(
train_data, PsddParameter::CreateFromDecimal(1));
end = get_time::now();
std::cout << "Learn Parameter Time : "
<< std::chrono::duration_cast<ms>(end - start).count() << " ms"
<< std::endl;
// Output Log Likelihood of training data
Probability training_ll = network->CalculateProbability(train_data);
std::cout << "Training Log Likelihood: " << training_ll.parameter() << std::endl;
std::cout << "Size of data: " << train_data->data_size() << std::endl;
// Check whether the examples in the dataset are consistent with the logical constraints of the SBN
// There is an underlying assumption that both the training and testing sets are consistent with
// the logical constraints of the SBN.
// Used for debugging and testing purposes.
if (options[CONSISTENT_CHECK]) {
assert(train_data != nullptr);
const auto &dataset = train_data->data();
std::bitset<MAX_VAR> mask;
mask.set();
for (const auto &cur_entry : dataset) {
std::cout << "Data : " << cur_entry.first << std::endl;
// IsModel checks whether an instantiation is a model of the SDD or not (i.e., a satisfying assignment)
if (network->IsModel(mask, cur_entry.first)) {
std::cout << "is a Model" << std::endl;
} else {
std::cout << "is not a Model" << std::endl;
}
}
}
delete (train_data);
}
if (options[PSDD_FILENAME]) {
const char *psdd_filename = options[PSDD_FILENAME].arg;
start = get_time::now();
NetworkCompiler *compiler =
NetworkCompiler::GetDefaultNetworkCompiler(network);
auto result = compiler->Compile();
end = get_time::now();
std::cout << "Compile Network Time : "
<< std::chrono::duration_cast<ms>(end - start).count() << " ms"
<< std::endl;
auto model_count = psdd_node_util::ModelCount(
psdd_node_util::SerializePsddNodes(result.first));
std::cout << "Model count " << model_count.get_str(10) << std::endl;
std::cout << "PSDD size" << psdd_node_util::GetPsddSize(result.first)
<< std::endl;
psdd_node_util::WritePsddToFile(result.first, psdd_filename);
if (options[VTREE_FILENAME]) {
const char *vtree_filename = options[VTREE_FILENAME].arg;
sdd_vtree_save(vtree_filename, compiler->GetVtree());
}
delete (compiler);
}
delete (network);
}