Skip to content

Commit ecbf659

Browse files
committed
Fix #3: allow stopping training early once the dev UAS changes by no more than a given tolerance
1 parent 253eed0 commit ecbf659

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

parser/lstm-parse.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
7171
("lstm_input_dim", po::value<unsigned>()->default_value(60), "LSTM input dimension")
7272
("train,t", "Should training be run?")
7373
("maxit,M", po::value<unsigned>()->default_value(8000), "Maximum number of training iterations")
74+
("tolerance", po::value<double>()->default_value(0.0), "Tolerance on dev uas for stopping training")
7475
("words,w", po::value<string>(), "Pretrained word embeddings")
7576
("help,h", "Help");
7677
po::options_description dcmdline_options;
@@ -525,6 +526,8 @@ int main(int argc, char** argv) {
525526
assert(unk_prob >= 0.); assert(unk_prob <= 1.);
526527
const unsigned maxit = conf["maxit"].as<unsigned>();
527528
cerr << "Maximum number of iterations: " << maxit << "\n";
529+
const double tolerance = conf["tolerance"].as<double>();
530+
cerr << "Optimization tolerance: " << tolerance << "\n";
528531
ostringstream os;
529532
os << "parser_" << (USE_POS ? "pos" : "nopos")
530533
<< '_' << LAYERS
@@ -607,9 +610,12 @@ int main(int argc, char** argv) {
607610
double llh = 0;
608611
bool first = true;
609612
unsigned iter = 0;
613+
double uas = -1;
614+
double prev_uas = -1;
610615
time_t time_start = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now());
611616
cerr << "TRAINING STARTED AT: " << put_time(localtime(&time_start), "%c %Z") << endl;
612-
while(!requested_stop && iter < maxit) {
617+
while(!requested_stop && iter < maxit &&
618+
(uas < 0 || prev_uas < 0 || abs(prev_uas - uas) > tolerance)) {
613619
for (unsigned sii = 0; sii < status_every_i_iterations; ++sii) {
614620
if (si == corpus.nsentences) {
615621
si = 0;
@@ -675,7 +681,9 @@ int main(int argc, char** argv) {
675681
total_heads += sentence.size() - 1;
676682
}
677683
auto t_end = std::chrono::high_resolution_clock::now();
678-
cerr << " **dev (iter=" << iter << " epoch=" << (tot_seen / corpus.nsentences) << ")\tllh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << "\t[" << dev_size << " sents in " << std::chrono::duration<double, std::milli>(t_end-t_start).count() << " ms]" << endl;
684+
prev_uas = uas;
685+
uas = correct_heads / total_heads;
686+
cerr << " **dev (iter=" << iter << " epoch=" << (tot_seen / corpus.nsentences) << ")\tllh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << uas << "\t[" << dev_size << " sents in " << std::chrono::duration<double, std::milli>(t_end-t_start).count() << " ms]" << endl;
679687
if (correct_heads > best_correct_heads) {
680688
best_correct_heads = correct_heads;
681689
ofstream out(fname);
@@ -698,6 +706,8 @@ int main(int argc, char** argv) {
698706
}
699707
if (iter >= maxit) {
700708
cerr << "\nMaximum number of iterations reached (" << iter << "), terminating optimization...\n";
709+
} else if (!requested_stop) {
710+
cerr << "\nScore tolerance reached (" << tolerance << "), terminating optimization...\n";
701711
}
702712
} // should do training?
703713
if (true) { // do test evaluation

0 commit comments

Comments
 (0)