Skip to content

Commit aa1b1db

Browse files
committed
Fix #3: allow limiting optimization by dev uas tolerance
1 parent 91a41a6 commit aa1b1db

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

parser/lstm-parse.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
7777
("lstm_input_dim", po::value<unsigned>()->default_value(60), "LSTM input dimension")
7878
("train,t", "Should training be run?")
7979
("maxit,M", po::value<unsigned>()->default_value(8000), "Maximum number of training iterations")
80+
("tolerance", po::value<double>()->default_value(0.0), "Tolerance on dev uas for stopping training")
8081
("words,w", po::value<string>(), "Pretrained word embeddings")
8182
("use_spelling,S", "Use spelling model") //Miguel. Spelling model
8283
("help,h", "Help");
@@ -946,6 +947,8 @@ int main(int argc, char** argv) {
946947
assert(unk_prob >= 0.); assert(unk_prob <= 1.);
947948
const unsigned maxit = conf["maxit"].as<unsigned>();
948949
cerr << "Maximum number of iterations: " << maxit << "\n";
950+
const double tolerance = conf["tolerance"].as<double>();
951+
cerr << "Optimization tolerance: " << tolerance << "\n";
949952
ostringstream os;
950953
os << "parser_" << (USE_POS ? "pos" : "nopos")
951954
<< '_' << LAYERS
@@ -1035,7 +1038,10 @@ int main(int argc, char** argv) {
10351038
double llh = 0;
10361039
bool first = true;
10371040
unsigned iter = 0;
1038-
while(!requested_stop && iter < maxit) {
1041+
double uas = -1;
1042+
double prev_uas = -1;
1043+
while(!requested_stop && iter < maxit &&
1044+
(uas < 0 || prev_uas < 0 || abs(prev_uas - uas) > tolerance)) {
10391045
for (unsigned sii = 0; sii < status_every_i_iterations; ++sii) {
10401046
if (si == corpus.nsentences) {
10411047
si = 0;
@@ -1103,7 +1109,9 @@ int main(int argc, char** argv) {
11031109
total_heads += sentence.size() - 1;
11041110
}
11051111
auto t_end = std::chrono::high_resolution_clock::now();
1106-
cerr << " **dev (iter=" << iter << " epoch=" << (tot_seen / corpus.nsentences) << ")\tllh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << "\t[" << dev_size << " sents in " << std::chrono::duration<double, std::milli>(t_end-t_start).count() << " ms]" << endl;
1112+
prev_uas = uas;
1113+
uas = correct_heads / total_heads;
1114+
cerr << " **dev (iter=" << iter << " epoch=" << (tot_seen / corpus.nsentences) << ")\tllh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << uas << "\t[" << dev_size << " sents in " << std::chrono::duration<double, std::milli>(t_end-t_start).count() << " ms]" << endl;
11071115
if (correct_heads > best_correct_heads) {
11081116
best_correct_heads = correct_heads;
11091117
ofstream out(fname);
@@ -1126,6 +1134,8 @@ int main(int argc, char** argv) {
11261134
}
11271135
if (iter >= maxit) {
11281136
cerr << "\nMaximum number of iterations reached (" << iter << "), terminating optimization...\n";
1137+
} else if (!requested_stop) {
1138+
cerr << "\nScore tolerance reached (" << tolerance << "), terminating optimization...\n";
11291139
}
11301140
} // should do training?
11311141
if (true) { // do test evaluation

0 commit comments

Comments
 (0)