@@ -428,7 +428,8 @@ void signal_callback_handler(int /* signum */) {
428
428
requested_stop = true ;
429
429
}
430
430
431
- unsigned compute_correct (const map<int ,int >& ref, const map<int ,int >& hyp, unsigned len) {
431
+ template <typename T>
432
+ unsigned compute_correct (const map<int ,T>& ref, const map<int ,T>& hyp, unsigned len) {
432
433
unsigned res = 0 ;
433
434
for (unsigned i = 0 ; i < len; ++i) {
434
435
auto ri = ref.find (i);
@@ -440,6 +441,24 @@ unsigned compute_correct(const map<int,int>& ref, const map<int,int>& hyp, unsig
440
441
return res;
441
442
}
442
443
444
+ template <typename T1, typename T2>
445
+ unsigned compute_correct (const map<int ,T1>& ref1, const map<int ,T1>& hyp1,
446
+ const map<int ,T2>& ref2, const map<int ,T2>& hyp2, unsigned len) {
447
+ unsigned res = 0 ;
448
+ for (unsigned i = 0 ; i < len; ++i) {
449
+ auto r1 = ref1.find (i);
450
+ auto h1 = hyp1.find (i);
451
+ auto r2 = ref2.find (i);
452
+ auto h2 = hyp2.find (i);
453
+ assert (r1 != ref1.end ());
454
+ assert (h1 != hyp1.end ());
455
+ assert (r2 != ref2.end ());
456
+ assert (h2 != hyp2.end ());
457
+ if (r1->second == h1->second && r2->second == h2->second ) ++res;
458
+ }
459
+ return res;
460
+ }
461
+
443
462
void output_conll (const vector<unsigned >& sentence, const vector<unsigned >& pos,
444
463
const vector<string>& sentenceUnkStrings,
445
464
const map<unsigned , string>& intToWords,
@@ -714,7 +733,8 @@ int main(int argc, char** argv) {
714
733
double llh = 0 ;
715
734
double trs = 0 ;
716
735
double right = 0 ;
717
- double correct_heads = 0 ;
736
+ double correct_heads_unlabeled = 0 ;
737
+ double correct_heads_labeled = 0 ;
718
738
double total_heads = 0 ;
719
739
auto t_start = std::chrono::high_resolution_clock::now ();
720
740
unsigned corpus_size = corpus.nsentencesDev ;
@@ -736,11 +756,12 @@ int main(int argc, char** argv) {
736
756
map<int ,int > ref = parser.compute_heads (sentence.size (), actions, corpus.actions , &rel_ref);
737
757
map<int ,int > hyp = parser.compute_heads (sentence.size (), pred, corpus.actions , &rel_hyp);
738
758
output_conll (sentence, sentencePos, sentenceUnkStr, corpus.intToWords , corpus.intToPos , hyp, rel_hyp);
739
- correct_heads += compute_correct (ref, hyp, sentence.size () - 1 );
759
+ correct_heads_unlabeled += compute_correct (ref, hyp, sentence.size () - 1 );
760
+ correct_heads_labeled += compute_correct (ref, hyp, rel_ref, rel_hyp, sentence.size () - 1 );
740
761
total_heads += sentence.size () - 1 ;
741
762
}
742
763
auto t_end = std::chrono::high_resolution_clock::now ();
743
- cerr << " TEST llh=" << llh << " ppl: " << exp (llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << " \t [" << corpus_size << " sents in " << std::chrono::duration<double , std::milli>(t_end-t_start).count () << " ms]" << endl;
764
+ cerr << " TEST llh=" << llh << " ppl: " << exp (llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads_unlabeled / total_heads) << " las: " << (correct_heads_labeled / total_heads) << " \t [" << corpus_size << " sents in " << std::chrono::duration<double , std::milli>(t_end-t_start).count () << " ms]" << endl;
744
765
}
745
766
for (unsigned i = 0 ; i < corpus.actions .size (); ++i) {
746
767
// cerr << corpus.actions[i] << '\t' << parser.p_r->values[i].transpose() << endl;
0 commit comments