Skip to content

Commit 1c557b0

Browse files
committed
Fix #2: calculate and print las on test
1 parent ecbf659 commit 1c557b0

File tree

1 file changed

+25
-4
lines changed

1 file changed

+25
-4
lines changed

parser/lstm-parse.cc

+25-4
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,8 @@ void signal_callback_handler(int /* signum */) {
428428
requested_stop = true;
429429
}
430430

431-
unsigned compute_correct(const map<int,int>& ref, const map<int,int>& hyp, unsigned len) {
431+
template<typename T>
432+
unsigned compute_correct(const map<int,T>& ref, const map<int,T>& hyp, unsigned len) {
432433
unsigned res = 0;
433434
for (unsigned i = 0; i < len; ++i) {
434435
auto ri = ref.find(i);
@@ -440,6 +441,24 @@ unsigned compute_correct(const map<int,int>& ref, const map<int,int>& hyp, unsig
440441
return res;
441442
}
442443

444+
template<typename T1, typename T2>
445+
unsigned compute_correct(const map<int,T1>& ref1, const map<int,T1>& hyp1,
446+
const map<int,T2>& ref2, const map<int,T2>& hyp2, unsigned len) {
447+
unsigned res = 0;
448+
for (unsigned i = 0; i < len; ++i) {
449+
auto r1 = ref1.find(i);
450+
auto h1 = hyp1.find(i);
451+
auto r2 = ref2.find(i);
452+
auto h2 = hyp2.find(i);
453+
assert(r1 != ref1.end());
454+
assert(h1 != hyp1.end());
455+
assert(r2 != ref2.end());
456+
assert(h2 != hyp2.end());
457+
if (r1->second == h1->second && r2->second == h2->second) ++res;
458+
}
459+
return res;
460+
}
461+
443462
void output_conll(const vector<unsigned>& sentence, const vector<unsigned>& pos,
444463
const vector<string>& sentenceUnkStrings,
445464
const map<unsigned, string>& intToWords,
@@ -714,7 +733,8 @@ int main(int argc, char** argv) {
714733
double llh = 0;
715734
double trs = 0;
716735
double right = 0;
717-
double correct_heads = 0;
736+
double correct_heads_unlabeled = 0;
737+
double correct_heads_labeled = 0;
718738
double total_heads = 0;
719739
auto t_start = std::chrono::high_resolution_clock::now();
720740
unsigned corpus_size = corpus.nsentencesDev;
@@ -736,11 +756,12 @@ int main(int argc, char** argv) {
736756
map<int,int> ref = parser.compute_heads(sentence.size(), actions, corpus.actions, &rel_ref);
737757
map<int,int> hyp = parser.compute_heads(sentence.size(), pred, corpus.actions, &rel_hyp);
738758
output_conll(sentence, sentencePos, sentenceUnkStr, corpus.intToWords, corpus.intToPos, hyp, rel_hyp);
739-
correct_heads += compute_correct(ref, hyp, sentence.size() - 1);
759+
correct_heads_unlabeled += compute_correct(ref, hyp, sentence.size() - 1);
760+
correct_heads_labeled += compute_correct(ref, hyp, rel_ref, rel_hyp, sentence.size() - 1);
740761
total_heads += sentence.size() - 1;
741762
}
742763
auto t_end = std::chrono::high_resolution_clock::now();
743-
cerr << "TEST llh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << "\t[" << corpus_size << " sents in " << std::chrono::duration<double, std::milli>(t_end-t_start).count() << " ms]" << endl;
764+
cerr << "TEST llh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads_unlabeled / total_heads) << " las: " << (correct_heads_labeled / total_heads) << "\t[" << corpus_size << " sents in " << std::chrono::duration<double, std::milli>(t_end-t_start).count() << " ms]" << endl;
744765
}
745766
for (unsigned i = 0; i < corpus.actions.size(); ++i) {
746767
//cerr << corpus.actions[i] << '\t' << parser.p_r->values[i].transpose() << endl;

0 commit comments

Comments
 (0)