Skip to content

Commit c9c9423

Browse files
committed
Fix #2: calculate and print las on test
1 parent aa1b1db commit c9c9423

File tree

1 file changed

+25
-4
lines changed

1 file changed

+25
-4
lines changed

parser/lstm-parse.cc

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,8 @@ void signal_callback_handler(int /* signum */) {
845845
requested_stop = true;
846846
}
847847

848-
unsigned compute_correct(const map<int,int>& ref, const map<int,int>& hyp, unsigned len) {
848+
template<typename T>
849+
unsigned compute_correct(const map<int,T>& ref, const map<int,T>& hyp, unsigned len) {
849850
unsigned res = 0;
850851
for (unsigned i = 0; i < len; ++i) {
851852
auto ri = ref.find(i);
@@ -857,6 +858,24 @@ unsigned compute_correct(const map<int,int>& ref, const map<int,int>& hyp, unsig
857858
return res;
858859
}
859860

861+
template<typename T1, typename T2>
862+
unsigned compute_correct(const map<int,T1>& ref1, const map<int,T1>& hyp1,
863+
const map<int,T2>& ref2, const map<int,T2>& hyp2, unsigned len) {
864+
unsigned res = 0;
865+
for (unsigned i = 0; i < len; ++i) {
866+
auto r1 = ref1.find(i);
867+
auto h1 = hyp1.find(i);
868+
auto r2 = ref2.find(i);
869+
auto h2 = hyp2.find(i);
870+
assert(r1 != ref1.end());
871+
assert(h1 != hyp1.end());
872+
assert(r2 != ref2.end());
873+
assert(h2 != hyp2.end());
874+
if (r1->second == h1->second && r2->second == h2->second) ++res;
875+
}
876+
return res;
877+
}
878+
860879
void output_conll(const vector<unsigned>& sentence, const vector<unsigned>& pos,
861880
const vector<string>& sentenceUnkStrings,
862881
const map<unsigned, string>& intToWords,
@@ -1142,7 +1161,8 @@ int main(int argc, char** argv) {
11421161
double llh = 0;
11431162
double trs = 0;
11441163
double right = 0;
1145-
double correct_heads = 0;
1164+
double correct_heads_unlabeled = 0;
1165+
double correct_heads_labeled = 0;
11461166
double total_heads = 0;
11471167
auto t_start = std::chrono::high_resolution_clock::now();
11481168
unsigned corpus_size = corpus.nsentencesDev;
@@ -1169,11 +1189,12 @@ int main(int argc, char** argv) {
11691189
map<int,int> ref = parser.compute_heads(sentence.size(), actions, corpus.actions, &rel_ref);
11701190
map<int,int> hyp = parser.compute_heads(sentence.size(), pred, corpus.actions, &rel_hyp);
11711191
output_conll(sentence, sentencePos, sentenceUnkStr, corpus.intToWords, corpus.intToPos, hyp, rel_hyp);
1172-
correct_heads += compute_correct(ref, hyp, sentence.size() - 1);
1192+
correct_heads_unlabeled += compute_correct(ref, hyp, sentence.size() - 1);
1193+
correct_heads_labeled += compute_correct(ref, hyp, rel_ref, rel_hyp, sentence.size() - 1);
11731194
total_heads += sentence.size() - 1;
11741195
}
11751196
auto t_end = std::chrono::high_resolution_clock::now();
1176-
cerr << "TEST llh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << "\t[" << corpus_size << " sents in " << std::chrono::duration<double, std::milli>(t_end-t_start).count() << " ms]" << endl;
1197+
cerr << "TEST llh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads_unlabeled / total_heads) << " las: " << (correct_heads_labeled / total_heads) << "\t[" << corpus_size << " sents in " << std::chrono::duration<double, std::milli>(t_end-t_start).count() << " ms]" << endl;
11771198
}
11781199
for (unsigned i = 0; i < corpus.actions.size(); ++i) {
11791200
//cerr << corpus.actions[i] << '\t' << parser.p_r->values[i].transpose() << endl;

0 commit comments

Comments
 (0)