diff --git a/deeprank2/trainer.py b/deeprank2/trainer.py index 1240b8694..fe73c7a45 100644 --- a/deeprank2/trainer.py +++ b/deeprank2/trainer.py @@ -577,7 +577,7 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc train_losses = [] valid_losses = [] - checkpoint_model = None + saved_model = False if earlystop_patience or earlystop_maxgap: early_stopping = EarlyStopping(patience=earlystop_patience, maxgap=earlystop_maxgap, min_epoch=min_epoch, trace_func=_log.info) @@ -610,6 +610,7 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc if best_model: if min(valid_losses) == loss_: checkpoint_model = self._save_model() + saved_model = True self.epoch_saved_model = epoch _log.info(f'Best model saved at epoch # {self.epoch_saved_model}.') # check early stopping criteria (in validation case only) @@ -627,17 +628,19 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc "Training data is used both for learning and model selection, which will to overfitting." + "\n\tIt is preferable to use an independent training and validation data sets.") checkpoint_model = self._save_model() + saved_model = True self.epoch_saved_model = epoch _log.info(f'Best model saved at epoch # {self.epoch_saved_model}.') # Save the last model - if best_model is False or checkpoint_model is None: + if best_model is False or not saved_model: checkpoint_model = self._save_model() self.epoch_saved_model = epoch _log.info(f'Last model saved at epoch # {self.epoch_saved_model}.') - if checkpoint_model is None: - _log.warning("A model has been saved but the training and validation losses were NaN;" + - "make sure that you are using enough data points during the trainig.") + if not saved_model: + _log.warning("A model has been saved but the validation and/or the training losses were NaN;" + + "try to increase the cutoff distance during the data processing or the number of data points" + + "during the training.") # Now that the training loop is over, save the model if filename: