diff --git a/src/worker.py b/src/worker.py index 47467067..4e3ef27a 100644 --- a/src/worker.py +++ b/src/worker.py @@ -1565,6 +1565,7 @@ def compute_GAN_train_or_test_classifier_accuracy_score(self, GAN_train=False, G train_top1_acc, train_top5_acc, train_loss = misc.AverageMeter(), misc.AverageMeter(), misc.AverageMeter() for i, (images, labels) in enumerate(self.train_dataloader): + optimizer.zero_grad() if GAN_train: images, labels, _, _, _, _, _ = sample.generate_images(z_prior=self.MODEL.z_prior, truncation_factor=self.RUN.truncation_factor, @@ -1616,11 +1617,12 @@ def compute_GAN_train_or_test_classifier_accuracy_score(self, GAN_train=False, G model_ = misc.peel_model(model) states = {"state_dict": model_.state_dict(), "optimizer": optimizer.state_dict(), "epoch": current_epoch+1, "best_top1": best_top1, "best_top5": best_top5, "best_epoch": best_epoch} + if self.local_rank == 0: + self.logger.info("Save model to {}".format(self.RUN.ckpt_dir)) misc.save_model_c(states, mode, self.RUN) if self.local_rank == 0: self.logger.info("Current best accuracy: Top-1: {top1:.4f}% and Top-5 {top5:.4f}%".format(top1=best_top1, top5=best_top5)) - self.logger.info("Save model to {}".format(self.RUN.ckpt_dir)) # ----------------------------------------------------------------------------- # validate GAN_train or GAN_test classifier using generated or training dataset @@ -1628,7 +1630,7 @@ def compute_GAN_train_or_test_classifier_accuracy_score(self, GAN_train=False, G def validate_classifier(self,model, generator, generator_mapping, generator_synthesis, epoch, GAN_test, setting): model.eval() valid_top1_acc, valid_top5_acc, valid_loss = misc.AverageMeter(), misc.AverageMeter(), misc.AverageMeter() - for i, (images, labels) in enumerate(self.train_dataloader): + for i, (images, labels) in enumerate(self.eval_dataloader): if GAN_test: images, labels, _, _, _, _, _ = sample.generate_images(z_prior=self.MODEL.z_prior, truncation_factor=self.RUN.truncation_factor,