diff --git a/onmt/models/model_saver.py b/onmt/models/model_saver.py
index fa1c890006..a34cef9280 100644
--- a/onmt/models/model_saver.py
+++ b/onmt/models/model_saver.py
@@ -37,7 +37,8 @@ def __init__(self, base_path, model, model_opt, fields, optim,
         if keep_checkpoint > 0:
             self.checkpoint_queue = deque([], maxlen=keep_checkpoint)

-    def save(self, step, moving_average=None):
+    def save(self, step, moving_average=None, best_step=None,
+             validation_ppl=None):
         """Main entry point for model saver

         It wraps the `_save` method with checks and apply `keep_checkpoint`
@@ -63,9 +64,22 @@ def save(self, step, moving_average=None):
                 param.data = param_data

         if self.keep_checkpoint > 0:
+            best_step, is_best = self._get_best_checkpoint(best_step,
+                                                           validation_ppl)
+            if is_best:
+                if best_step is None:
+                    best_checkpoint = '%s_step_%d.pt' % (self.base_path, step)
+                    self._update_best_config(step, validation_ppl)
+                else:
+                    best_checkpoint = '%s_step_%d.pt' \
+                        % (self.base_path, best_step)
+                    self._update_best_config(best_step, validation_ppl)
+            else:
+                best_checkpoint = '%s_step_%d.pt' % (self.base_path, best_step)
             if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
                 todel = self.checkpoint_queue.popleft()
-                self._rm_checkpoint(todel)
+                if todel != best_checkpoint:
+                    self._rm_checkpoint(todel)
             self.checkpoint_queue.append(chkpt_name)

     def _save(self, step):
@@ -133,3 +147,26 @@ def _save(self, step, model):
     def _rm_checkpoint(self, name):
         if os.path.exists(name):
             os.remove(name)
+
+    def _get_best_checkpoint(self, best_step, validation_ppl):
+        import json
+        is_best = False
+        best_ckpt_config_file = self.base_path + 'best_ckpt_config.json'
+        if os.path.exists(best_ckpt_config_file):
+            with open(best_ckpt_config_file, 'r') as best_config:
+                best_ckpt_dict = json.load(best_config)
+            if validation_ppl < best_ckpt_dict['validation_ppl']:
+                is_best = True
+            else:
+                best_step = best_ckpt_dict['step']
+        else:
+            is_best = True
+
+        return best_step, is_best
+
+    def _update_best_config(self, step, validation_ppl):
+        import json
+        best_ckpt_config_file = self.base_path + 'best_ckpt_config.json'
+        with open(best_ckpt_config_file, 'w') as best_config:
+            best_ckpt_dict = {'step': step, 'validation_ppl': validation_ppl}
+            json.dump(best_ckpt_dict, best_config)
diff --git a/onmt/trainer.py b/onmt/trainer.py
index 334ab02f12..8ea5c34864 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -241,6 +241,7 @@ def train(self,
         total_stats = onmt.utils.Statistics()
         report_stats = onmt.utils.Statistics()
         self._start_report_manager(start_time=total_stats.start_time)
+        early_stopped = False

         for i, (batches, normalization) in enumerate(
                 self._accum_batches(train_iter)):
@@ -272,7 +273,17 @@ def train(self,
                     self.optim.learning_rate(),
                     report_stats)

-            if valid_iter is not None and step % valid_steps == 0:
+            save_at_current_step = self.model_saver is not None \
+                and (save_checkpoint_steps != 0 and
+                     step % save_checkpoint_steps == 0)
+            validate_at_current_step = (
+                valid_iter is not None and step % valid_steps == 0
+            )
+            is_final_step = step == train_steps
+
+            # Force validation in any of the above cases
+            if (save_at_current_step or validate_at_current_step
+                    or is_final_step):
                 if self.gpu_verbose_level > 0:
                     logger.info('GpuRank %d: validate step %d'
                                 % (self.gpu_rank, step))
@@ -287,23 +298,31 @@ def train(self,
                                 % (self.gpu_rank, step))
                     self._report_step(self.optim.learning_rate(),
                                       step, valid_stats=valid_stats)
-                # Run patience mechanism
-                if self.earlystopper is not None:
-                    self.earlystopper(valid_stats, step)
-                    # If the patience has reached the limit, stop training
-                    if self.earlystopper.has_stopped():
-                        break
-
-            if (self.model_saver is not None
-                and (save_checkpoint_steps != 0
-                     and step % save_checkpoint_steps == 0)):
-                self.model_saver.save(step, moving_average=self.moving_average)
+
+                if validate_at_current_step:
+                    # Run patience mechanism
+                    if self.earlystopper is not None:
+                        self.earlystopper(valid_stats, step)
+                        # If the patience has reached the limit, stop training
+                        if self.earlystopper.has_stopped():
+                            early_stopped = True
+                            break
+                if save_at_current_step:
+                    self.model_saver.save(step,
+                                          moving_average=self.moving_average,
+                                          validation_ppl=valid_stats.ppl())

             if train_steps > 0 and step >= train_steps:
                 break

         if self.model_saver is not None:
-            self.model_saver.save(step, moving_average=self.moving_average)
+            best_step = None
+            if early_stopped:
+                best_step = self.earlystopper.current_step_best
+
+            self.model_saver.save(step, moving_average=self.moving_average,
+                                  best_step=best_step,
+                                  validation_ppl=valid_stats.ppl())
         return total_stats

     def validate(self, valid_iter, moving_average=None):
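
For reference, here is a minimal standalone sketch of the best-checkpoint bookkeeping the model_saver.py hunks introduce: _update_best_config persists a small JSON sidecar recording the best step and its validation perplexity, and _get_best_checkpoint decides whether an incoming perplexity beats the recorded one. The base_path prefix and the perplexity values below are hypothetical, and the helpers are written as free functions purely to keep the sketch self-contained and runnable.

import json
import os

base_path = 'model'  # hypothetical save prefix, standing in for self.base_path
config_file = base_path + 'best_ckpt_config.json'

def update_best_config(step, validation_ppl):
    # Mirror of _update_best_config: record the best step and its ppl.
    with open(config_file, 'w') as best_config:
        json.dump({'step': step, 'validation_ppl': validation_ppl}, best_config)

def get_best_checkpoint(best_step, validation_ppl):
    # Mirror of _get_best_checkpoint: the very first save is always "best";
    # later saves are best only if they improve on the stored perplexity.
    is_best = False
    if os.path.exists(config_file):
        with open(config_file, 'r') as best_config:
            best_ckpt_dict = json.load(best_config)
        if validation_ppl < best_ckpt_dict['validation_ppl']:
            is_best = True
        else:
            best_step = best_ckpt_dict['step']
    else:
        is_best = True
    return best_step, is_best

update_best_config(5000, 12.3)
print(get_best_checkpoint(None, 11.8))  # (None, True): the caller's step becomes the new best
print(get_best_checkpoint(None, 13.0))  # (5000, False): step 5000 stays the best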
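
Likewise, a small sketch of how the guarded deletion in save() interacts with the rotating deque(maxlen=keep_checkpoint) queue: once the queue is full, the oldest checkpoint name is popped, but it is only removed from disk when it differs from the current best. The file names are hypothetical and the sketch merely reports what would be deleted.

from collections import deque

keep_checkpoint = 2
checkpoint_queue = deque([], maxlen=keep_checkpoint)
best_checkpoint = 'model_step_1000.pt'  # hypothetical current best

def rotate(chkpt_name):
    # Mirror of the queue handling in save(): pop the oldest entry once the
    # queue is full, but skip deletion when that entry is the best checkpoint.
    deleted = None
    if len(checkpoint_queue) == checkpoint_queue.maxlen:
        todel = checkpoint_queue.popleft()
        if todel != best_checkpoint:
            deleted = todel  # stand-in for self._rm_checkpoint(todel)
    checkpoint_queue.append(chkpt_name)
    return deleted

for step in (1000, 2000, 3000, 4000):
    name = 'model_step_%d.pt' % step
    print(name, '-> deleted:', rotate(name))
# model_step_1000.pt is popped when step 3000 is saved but survives, because it
# matches best_checkpoint; model_step_2000.pt is deleted when step 4000 is saved.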