@@ -845,7 +845,8 @@ void signal_callback_handler(int /* signum */) {
845845 requested_stop = true ;
846846}
847847
848- unsigned compute_correct (const map<int ,int >& ref, const map<int ,int >& hyp, unsigned len) {
848+ template <typename T>
849+ unsigned compute_correct (const map<int ,T>& ref, const map<int ,T>& hyp, unsigned len) {
849850 unsigned res = 0 ;
850851 for (unsigned i = 0 ; i < len; ++i) {
851852 auto ri = ref.find (i);
@@ -857,6 +858,24 @@ unsigned compute_correct(const map<int,int>& ref, const map<int,int>& hyp, unsig
857858 return res;
858859}
859860
861+ template <typename T1, typename T2>
862+ unsigned compute_correct (const map<int ,T1>& ref1, const map<int ,T1>& hyp1,
863+ const map<int ,T2>& ref2, const map<int ,T2>& hyp2, unsigned len) {
864+ unsigned res = 0 ;
865+ for (unsigned i = 0 ; i < len; ++i) {
866+ auto r1 = ref1.find (i);
867+ auto h1 = hyp1.find (i);
868+ auto r2 = ref2.find (i);
869+ auto h2 = hyp2.find (i);
870+ assert (r1 != ref1.end ());
871+ assert (h1 != hyp1.end ());
872+ assert (r2 != ref2.end ());
873+ assert (h2 != hyp2.end ());
874+ if (r1->second == h1->second && r2->second == h2->second ) ++res;
875+ }
876+ return res;
877+ }
878+
860879void output_conll (const vector<unsigned >& sentence, const vector<unsigned >& pos,
861880 const vector<string>& sentenceUnkStrings,
862881 const map<unsigned , string>& intToWords,
@@ -1142,7 +1161,8 @@ int main(int argc, char** argv) {
11421161 double llh = 0 ;
11431162 double trs = 0 ;
11441163 double right = 0 ;
1145- double correct_heads = 0 ;
1164+ double correct_heads_unlabeled = 0 ;
1165+ double correct_heads_labeled = 0 ;
11461166 double total_heads = 0 ;
11471167 auto t_start = std::chrono::high_resolution_clock::now ();
11481168 unsigned corpus_size = corpus.nsentencesDev ;
@@ -1169,11 +1189,12 @@ int main(int argc, char** argv) {
11691189 map<int ,int > ref = parser.compute_heads (sentence.size (), actions, corpus.actions , &rel_ref);
11701190 map<int ,int > hyp = parser.compute_heads (sentence.size (), pred, corpus.actions , &rel_hyp);
11711191 output_conll (sentence, sentencePos, sentenceUnkStr, corpus.intToWords , corpus.intToPos , hyp, rel_hyp);
1172- correct_heads += compute_correct (ref, hyp, sentence.size () - 1 );
1192+ correct_heads_unlabeled += compute_correct (ref, hyp, sentence.size () - 1 );
1193+ correct_heads_labeled += compute_correct (ref, hyp, rel_ref, rel_hyp, sentence.size () - 1 );
11731194 total_heads += sentence.size () - 1 ;
11741195 }
11751196 auto t_end = std::chrono::high_resolution_clock::now ();
1176- cerr << " TEST llh=" << llh << " ppl: " << exp (llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << " \t [" << corpus_size << " sents in " << std::chrono::duration<double , std::milli>(t_end-t_start).count () << " ms]" << endl;
1197+ cerr << " TEST llh=" << llh << " ppl: " << exp (llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads_unlabeled / total_heads) << " las: " << (correct_heads_labeled / total_heads) << " \t [" << corpus_size << " sents in " << std::chrono::duration<double , std::milli>(t_end-t_start).count () << " ms]" << endl;
11771198 }
11781199 for (unsigned i = 0 ; i < corpus.actions .size (); ++i) {
11791200 // cerr << corpus.actions[i] << '\t' << parser.p_r->values[i].transpose() << endl;
0 commit comments