@@ -71,6 +71,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
71
71
(" lstm_input_dim" , po::value<unsigned >()->default_value (60 ), " LSTM input dimension" )
72
72
(" train,t" , " Should training be run?" )
73
73
(" maxit,M" , po::value<unsigned >()->default_value (8000 ), " Maximum number of training iterations" )
74
+ (" tolerance" , po::value<double >()->default_value (0.0 ), " Tolerance on dev uas for stopping training" )
74
75
(" words,w" , po::value<string>(), " Pretrained word embeddings" )
75
76
(" help,h" , " Help" );
76
77
po::options_description dcmdline_options;
@@ -525,6 +526,8 @@ int main(int argc, char** argv) {
525
526
assert (unk_prob >= 0 .); assert (unk_prob <= 1 .);
526
527
const unsigned maxit = conf[" maxit" ].as <unsigned >();
527
528
cerr << " Maximum number of iterations: " << maxit << " \n " ;
529
+ const double tolerance = conf[" tolerance" ].as <double >();
530
+ cerr << " Optimization tolerance: " << tolerance << " \n " ;
528
531
ostringstream os;
529
532
os << " parser_" << (USE_POS ? " pos" : " nopos" )
530
533
<< ' _' << LAYERS
@@ -607,9 +610,12 @@ int main(int argc, char** argv) {
607
610
double llh = 0 ;
608
611
bool first = true ;
609
612
unsigned iter = 0 ;
613
+ double uas = -1 ;
614
+ double prev_uas = -1 ;
610
615
time_t time_start = std::chrono::system_clock::to_time_t (std::chrono::system_clock::now ());
611
616
cerr << " TRAINING STARTED AT: " << put_time (localtime (&time_start), " %c %Z" ) << endl;
612
- while (!requested_stop && iter < maxit) {
617
+ while (!requested_stop && iter < maxit &&
618
+ (uas < 0 || prev_uas < 0 || abs (prev_uas - uas) > tolerance)) {
613
619
for (unsigned sii = 0 ; sii < status_every_i_iterations; ++sii) {
614
620
if (si == corpus.nsentences ) {
615
621
si = 0 ;
@@ -675,7 +681,9 @@ int main(int argc, char** argv) {
675
681
total_heads += sentence.size () - 1 ;
676
682
}
677
683
auto t_end = std::chrono::high_resolution_clock::now ();
678
- cerr << " **dev (iter=" << iter << " epoch=" << (tot_seen / corpus.nsentences ) << " )\t llh=" << llh << " ppl: " << exp (llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << " \t [" << dev_size << " sents in " << std::chrono::duration<double , std::milli>(t_end-t_start).count () << " ms]" << endl;
684
+ prev_uas = uas;
685
+ uas = correct_heads / total_heads;
686
+ cerr << " **dev (iter=" << iter << " epoch=" << (tot_seen / corpus.nsentences ) << " )\t llh=" << llh << " ppl: " << exp (llh / trs) << " err: " << (trs - right) / trs << " uas: " << uas << " \t [" << dev_size << " sents in " << std::chrono::duration<double , std::milli>(t_end-t_start).count () << " ms]" << endl;
679
687
if (correct_heads > best_correct_heads) {
680
688
best_correct_heads = correct_heads;
681
689
ofstream out (fname);
@@ -698,6 +706,8 @@ int main(int argc, char** argv) {
698
706
}
699
707
if (iter >= maxit) {
700
708
cerr << " \n Maximum number of iterations reached (" << iter << " ), terminating optimization...\n " ;
709
+ } else if (!requested_stop) {
710
+ cerr << " \n Score tolerance reached (" << tolerance << " ), terminating optimization...\n " ;
701
711
}
702
712
} // should do training?
703
713
if (true ) { // do test evaluation
0 commit comments