diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index e7b91e167..03a387484 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -315,6 +315,7 @@ def metrics(self): m = [self.noise_loss_tracker, self.image_loss_tracker, self.mse_metric, self.mae_metric] if self.tensor_map.axes() == 3 and self.inspect_model: m.append(self.kid) + m.append(self.inception_score) return m def denormalize(self, images): @@ -487,6 +488,7 @@ def test_step(self, images_original): num_images=self.batch_size, diffusion_steps=20 ) self.kid.update_state(images, generated_images) + self.inception_score.update_state(images, generated_images) return {m.name: m.result() for m in self.metrics} @@ -676,6 +678,7 @@ def metrics(self): m.append(self.supervised_loss_tracker) if self.input_map.axes() == 3 and self.inspect_model: m.append(self.kid) + m.append(self.inception_score) return m def denormalize(self, images): @@ -883,6 +886,7 @@ def test_step(self, batch): num_images=self.batch_size, diffusion_steps=20 ) self.kid.update_state(images, generated_images) + self.inception_score.update_state(images, generated_images) return {m.name: m.result() for m in self.metrics}