diff --git a/ml4h/metrics.py b/ml4h/metrics.py index 79f636bfc..963816b7c 100755 --- a/ml4h/metrics.py +++ b/ml4h/metrics.py @@ -833,3 +833,29 @@ def result(self): def reset_state(self): self.is_tracker.reset_state() + +class MultiScaleSSIM(keras.metrics.Metric): + def __init__(self, max_val=6.0, name="multi_scale_ssim", **kwargs): + super(MultiScaleSSIM, self).__init__(name=name, **kwargs) + self.max_val = max_val + self.total_ssim = self.add_weight(name="total_ssim", initializer="zeros") + self.count = self.add_weight(name="count", initializer="zeros") + + def update_state(self, y_true, y_pred, sample_weight=None): + # Calculate MS-SSIM for the batch + ssim = tf.image.ssim_multiscale(y_true, y_pred, max_val=self.max_val) + if sample_weight is not None: + ssim = tf.multiply(ssim, sample_weight) + + # Update total MS-SSIM and count + self.total_ssim.assign_add(tf.reduce_sum(ssim)) + self.count.assign_add(tf.cast(tf.size(ssim), tf.float32)) + + def result(self): + # Return the mean MS-SSIM over all batches + return tf.divide(self.total_ssim, self.count) + + def reset_states(self): + # Reset the metric state variables + self.total_ssim.assign(0.0) + self.count.assign(0.0) \ No newline at end of file diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index abdffe543..a0140cd32 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -12,7 +12,7 @@ from keras import layers from ml4h.defines import IMAGE_EXT -from ml4h.metrics import KernelInceptionDistance, InceptionScore +from ml4h.metrics import KernelInceptionDistance, InceptionScore, MultiScaleSSIM from ml4h.models.Block import Block from ml4h.TensorMap import TensorMap @@ -669,7 +669,7 @@ def compile(self, **kwargs): self.supervised_loss_tracker = keras.metrics.Mean(name="supervised_loss") if self.input_map.axes() == 3 and self.inspect_model: self.kid = KernelInceptionDistance(name = "kid", input_shape = self.input_map.shape, kernel_image_size=299) - self.inception_score = InceptionScore(name = "is", input_shape = self.input_map.shape, kernel_image_size=299) + self.ms_ssim = MultiScaleSSIM() @property def metrics(self): @@ -678,7 +678,7 @@ def metrics(self): m.append(self.supervised_loss_tracker) if self.input_map.axes() == 3 and self.inspect_model: m.append(self.kid) - m.append(self.inception_score) + m.append(self.ms_ssim) return m def denormalize(self, images): @@ -886,7 +886,7 @@ def test_step(self, batch): num_images=self.batch_size, diffusion_steps=20 ) self.kid.update_state(images, generated_images) - self.inception_score.update_state(images) + self.ms_ssim.update_state(images, generated_images) return {m.name: m.result() for m in self.metrics}