diff --git a/ml4h/metrics.py b/ml4h/metrics.py index 58b20fb92..e68c21d67 100755 --- a/ml4h/metrics.py +++ b/ml4h/metrics.py @@ -842,7 +842,7 @@ def __init__(self, name="multi_scale_ssim", **kwargs): def update_state(self, y_true, y_pred, max_val, sample_weight=None): # Calculate MS-SSIM for the batch - ssim = tf.image.ssim_multiscale(y_true, y_pred, max_val=max_val, power_factors=[0.0, 0.0, 0.5, 0.5]) + ssim = tf.image.ssim_multiscale(y_true, y_pred, max_val=max_val, power_factors=[0.1, 0.2, 0.4, 0.3]) if sample_weight is not None: ssim = tf.multiply(ssim, sample_weight) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 486287ed5..e2277ab8d 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -886,8 +886,6 @@ def test_step(self, batch): num_images=self.batch_size, diffusion_steps=20 ) self.kid.update_state(images, generated_images) - max_pixel_value = tf.reduce_max(tf.abs(generated_images)) - max_val = 2 * max_pixel_value # Double the max absolute value self.ms_ssim.update_state(images, generated_images, 255) return {m.name: m.result() for m in self.metrics}