diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 39b1088a6..abdffe543 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -886,7 +886,7 @@ def test_step(self, batch): num_images=self.batch_size, diffusion_steps=20 ) self.kid.update_state(images, generated_images) - self.inception_score.update_state(generated_images) + self.inception_score.update_state(images) return {m.name: m.result() for m in self.metrics}