@@ -259,8 +259,8 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s
259
259
260
260
for i , width in enumerate (reversed (widths [:- 1 ])):
261
261
if attention_modulo > 1 and ((len (widths ) - 1 ) - i ) % attention_modulo == 0 :
262
- if len (input_shape ) > 2 :
263
- c2 = upsample (size = x .shape [1 : - 1 ]* 2 )(control [control_idxs ])
262
+ if len (input_shape ) == 3 :
263
+ c2 = upsample (size = ( x .shape [1 ]* 2 , x . shape [ 2 ] * 2 ) )(control [control_idxs ])
264
264
else :
265
265
c2 = upsample (size = x .shape [- 2 ]* 2 )(control [control_idxs ])
266
266
x = up_block_control (width , block_depth , conv , upsample ,
@@ -284,7 +284,7 @@ def get_control_embed_model(output_maps, control_size):
284
284
285
285
286
286
class DiffusionModel (keras .Model ):
287
- def __init__ (self , tensor_map , batch_size , widths , block_depth , kernel_size , diffusion_loss , sigmoid_beta ):
287
+ def __init__ (self , tensor_map , batch_size , widths , block_depth , kernel_size , diffusion_loss , sigmoid_beta , inspect_model ):
288
288
super ().__init__ ()
289
289
290
290
self .tensor_map = tensor_map
@@ -294,6 +294,7 @@ def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, dif
294
294
self .ema_network = keras .models .clone_model (self .network )
295
295
self .use_sigmoid_loss = diffusion_loss == 'sigmoid'
296
296
self .beta = sigmoid_beta
297
+ self .inspect_model = inspect_model
297
298
298
299
def can_apply (self ):
299
300
return self .tensor_map .axes () > 1
@@ -303,13 +304,15 @@ def compile(self, **kwargs):
303
304
304
305
self .noise_loss_tracker = keras .metrics .Mean (name = "n_loss" )
305
306
self .image_loss_tracker = keras .metrics .Mean (name = "i_loss" )
306
- if self .tensor_map .axes () == 3 :
307
- self .kid = KernelInceptionDistance (name = "kid" , input_shape = self .tensor_map .shape , kernel_image_size = 75 )
307
+ self .mse_metric = tf .keras .metrics .MeanSquaredError (name = "mse" )
308
+ self .mae_metric = tf .keras .metrics .MeanAbsoluteError (name = "mae" )
309
+ if self .tensor_map .axes () == 3 and self .inspect_model :
310
+ self .kid = KernelInceptionDistance (name = "kid" , input_shape = self .tensor_map .shape , kernel_image_size = 299 )
308
311
309
312
@property
310
313
def metrics (self ):
311
- m = [self .noise_loss_tracker , self .image_loss_tracker ]
312
- if self .tensor_map .axes () == 3 :
314
+ m = [self .noise_loss_tracker , self .image_loss_tracker , self . mse_metric , self . mae_metric ]
315
+ if self .tensor_map .axes () == 3 and self . inspect_model :
313
316
m .append (self .kid )
314
317
return m
315
318
@@ -428,13 +431,15 @@ def train_step(self, images_original):
428
431
429
432
self .noise_loss_tracker .update_state (noise_loss )
430
433
self .image_loss_tracker .update_state (image_loss )
434
+ self .mse_metric .update_state (noises , pred_noises )
435
+ self .mae_metric .update_state (noises , pred_noises )
431
436
432
437
# track the exponential moving averages of weights
433
438
for weight , ema_weight in zip (self .network .weights , self .ema_network .weights ):
434
439
ema_weight .assign (ema * ema_weight + (1 - ema ) * weight )
435
440
436
441
# KID is not measured during the training phase for computational efficiency
437
- return {m .name : m .result () for m in self .metrics [: - 1 ] }
442
+ return {m .name : m .result () for m in self .metrics }
438
443
439
444
def test_step (self , images_original ):
440
445
# normalize images to have standard deviation of 1, like the noises
@@ -470,10 +475,12 @@ def test_step(self, images_original):
470
475
471
476
self .image_loss_tracker .update_state (image_loss )
472
477
self .noise_loss_tracker .update_state (noise_loss )
478
+ self .mse_metric .update_state (noises , pred_noises )
479
+ self .mae_metric .update_state (noises , pred_noises )
473
480
474
481
# measure KID between real and generated images
475
482
# this is computationally demanding, kid_diffusion_steps has to be small
476
- if self .tensor_map .axes () == 3 :
483
+ if self .tensor_map .axes () == 3 and self . inspect_model :
477
484
images = self .denormalize (images )
478
485
generated_images = self .generate (
479
486
num_images = self .batch_size , diffusion_steps = 20
@@ -534,15 +541,14 @@ def plot_ecgs(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None,
534
541
plt .close ()
535
542
536
543
def plot_reconstructions (
537
- self , images_original , diffusion_amount = 0 ,
538
- epoch = None , logs = None , num_rows = 3 , num_cols = 6 ,
544
+ self , images_original , diffusion_amount = 0 , epoch = None , logs = None , num_rows = 2 , num_cols = 2 , prefix = './figures/' ,
539
545
):
540
546
images = images_original [0 ][self .tensor_map .input_name ()]
541
547
self .normalizer .update_state (images )
542
548
images = self .normalizer (images , training = False )
543
- noises = tf .random .normal (shape = (self . batch_size ,) + self .tensor_map .shape )
549
+ noises = tf .random .normal (shape = (num_rows * num_cols ,) + self .tensor_map .shape )
544
550
545
- diffusion_times = diffusion_amount * tf .ones (shape = [self . batch_size ] + [1 ] * self .tensor_map .axes ())
551
+ diffusion_times = diffusion_amount * tf .ones (shape = [num_rows * num_cols ] + [1 ] * self .tensor_map .axes ())
546
552
noise_rates , signal_rates = self .diffusion_schedule (diffusion_times )
547
553
# mix the images with noises accordingly
548
554
noisy_images = signal_rates * images + noise_rates * noises
@@ -559,8 +565,27 @@ def plot_reconstructions(
559
565
plt .imshow (generated_images [index ], cmap = 'gray' )
560
566
plt .axis ("off" )
561
567
plt .tight_layout ()
562
- plt .show ()
568
+ now_string = datetime .datetime .now ().strftime ('%Y-%m-%d_%H-%M' )
569
+ figure_path = os .path .join (prefix , f'diffusion_reconstructions_{ now_string } { IMAGE_EXT } ' )
570
+ if not os .path .exists (os .path .dirname (figure_path )):
571
+ os .makedirs (os .path .dirname (figure_path ))
572
+ plt .savefig (figure_path , bbox_inches = "tight" )
563
573
plt .close ()
574
+ plt .figure (figsize = (num_cols * 2.0 , num_rows * 2.0 ), dpi = 300 )
575
+ for row in range (num_rows ):
576
+ for col in range (num_cols ):
577
+ index = row * num_cols + col
578
+ plt .subplot (num_rows , num_cols , index + 1 )
579
+ plt .imshow (images [index ], cmap = 'gray' )
580
+ plt .axis ("off" )
581
+ plt .tight_layout ()
582
+ now_string = datetime .datetime .now ().strftime ('%Y-%m-%d_%H-%M' )
583
+ figure_path = os .path .join (prefix , f'input_images_{ now_string } { IMAGE_EXT } ' )
584
+ if not os .path .exists (os .path .dirname (figure_path )):
585
+ os .makedirs (os .path .dirname (figure_path ))
586
+ plt .savefig (figure_path , bbox_inches = "tight" )
587
+ plt .close ()
588
+ return generated_images
564
589
565
590
def in_paint (self , images_original , masks , diffusion_steps = 64 , num_rows = 3 , num_cols = 6 ):
566
591
images = images_original [0 ][self .tensor_map .input_name ()]
@@ -612,7 +637,7 @@ class DiffusionController(keras.Model):
612
637
def __init__ (
613
638
self , tensor_map , output_maps , batch_size , widths , block_depth , conv_x , control_size ,
614
639
attention_start , attention_heads , attention_modulo , diffusion_loss , sigmoid_beta , condition_strategy ,
615
- supervisor = None ,
640
+ inspect_model , supervisor = None , supervision_scalar = 0.01 ,
616
641
):
617
642
super ().__init__ ()
618
643
@@ -627,22 +652,28 @@ def __init__(
627
652
self .use_sigmoid_loss = diffusion_loss == 'sigmoid'
628
653
self .beta = sigmoid_beta
629
654
self .supervisor = supervisor
655
+ self .supervision_scalar = supervision_scalar
656
+ self .inspect_model = inspect_model
630
657
631
658
632
659
def compile (self , ** kwargs ):
633
660
super ().compile (** kwargs )
634
-
635
661
self .noise_loss_tracker = keras .metrics .Mean (name = "n_loss" )
636
662
self .image_loss_tracker = keras .metrics .Mean (name = "i_loss" )
663
+ self .mse_metric = tf .keras .metrics .MeanSquaredError (name = "mse" )
664
+ self .mae_metric = tf .keras .metrics .MeanAbsoluteError (name = "mae" )
637
665
if self .supervisor is not None :
638
666
self .supervised_loss_tracker = keras .metrics .Mean (name = "supervised_loss" )
639
- # self.kid = KID(name = "kid", input_shape = self.tensor_map.shape)
667
+ if self .input_map .axes () == 3 and self .inspect_model :
668
+ self .kid = KernelInceptionDistance (name = "kid" , input_shape = self .input_map .shape , kernel_image_size = 299 )
640
669
641
670
@property
642
671
def metrics (self ):
643
- m = [self .noise_loss_tracker , self .image_loss_tracker ]
672
+ m = [self .noise_loss_tracker , self .image_loss_tracker , self . mse_metric , self . mae_metric ]
644
673
if self .supervisor is not None :
645
674
m .append (self .supervised_loss_tracker )
675
+ if self .input_map .axes () == 3 and self .inspect_model :
676
+ m .append (self .kid )
646
677
return m
647
678
648
679
def denormalize (self , images ):
@@ -764,12 +795,15 @@ def train_step(self, batch):
764
795
weight = tf .math .sigmoid (self .beta - lambda_t )
765
796
noise_loss = weight * noise_loss
766
797
if self .supervisor is not None :
767
- loss_fn = tf .keras .losses .MeanSquaredError ()
798
+ if self .output_maps [0 ].is_categorical ():
799
+ loss_fn = tf .keras .losses .CategoricalCrossentropy ()
800
+ else :
801
+ loss_fn = tf .keras .losses .MeanSquaredError ()
768
802
supervised_preds = self .supervisor (pred_images , training = True )
769
803
supervised_loss = loss_fn (batch [1 ][self .output_maps [0 ].output_name ()], supervised_preds )
770
804
self .supervised_loss_tracker .update_state (supervised_loss )
771
805
# Combine losses: add noise_loss and supervised_loss
772
- noise_loss += 0.01 * supervised_loss
806
+ noise_loss += self . supervision_scalar * supervised_loss
773
807
774
808
# Gradients for self.supervised_model
775
809
supervised_gradients = tape .gradient (supervised_loss , self .supervisor .trainable_weights )
@@ -780,50 +814,15 @@ def train_step(self, batch):
780
814
781
815
self .noise_loss_tracker .update_state (noise_loss )
782
816
self .image_loss_tracker .update_state (image_loss )
817
+ self .mse_metric .update_state (noises , pred_noises )
818
+ self .mae_metric .update_state (noises , pred_noises )
783
819
784
820
# track the exponential moving averages of weights
785
821
for weight , ema_weight in zip (self .network .weights , self .ema_network .weights ):
786
822
ema_weight .assign (ema * ema_weight + (1 - ema ) * weight )
787
823
788
824
# KID is not measured during the training phase for computational efficiency
789
- return {m .name : m .result () for m in self .metrics [:- 1 ]}
790
-
791
- # def call(self, inputs):
792
- # # normalize images to have standard deviation of 1, like the noises
793
- # images = inputs[self.input_map.input_name()]
794
- # self.normalizer.update_state(images)
795
- # images = self.normalizer(images, training=False)
796
-
797
- # control_embed = self.control_embed_model(inputs)
798
-
799
- # noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape)
800
-
801
- # # sample uniform random diffusion times
802
- # diffusion_times = tf.random.uniform(
803
- # shape=[self.batch_size, ] + [1] * self.input_map.axes(), minval=0.0, maxval=1.0
804
- # )
805
- # noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
806
- # # mix the images with noises accordingly
807
- # noisy_images = signal_rates * images + noise_rates * noises
808
-
809
- # # use the network to separate noisy images to their components
810
- # pred_noises, pred_images = self.denoise(
811
- # control_embed, noisy_images, noise_rates, signal_rates, training=False
812
- # )
813
-
814
- # noise_loss = self.loss(noises, pred_noises)
815
- # image_loss = self.loss(images, pred_images)
816
-
817
- # self.image_loss_tracker.update_state(image_loss)
818
- # self.noise_loss_tracker.update_state(noise_loss)
819
-
820
- # # measure KID between real and generated images
821
- # # this is computationally demanding, kid_diffusion_steps has to be small
822
- # images = self.denormalize(images)
823
- # generated_images = self.generate(
824
- # control_embed, num_images=self.batch_size, diffusion_steps=20
825
- # )
826
- # return generated_images
825
+ return {m .name : m .result () for m in self .metrics }
827
826
828
827
def test_step (self , batch ):
829
828
# normalize images to have standard deviation of 1, like the noises
@@ -859,26 +858,33 @@ def test_step(self, batch):
859
858
weight = tf .math .sigmoid (self .beta - lambda_t )
860
859
noise_loss = weight * noise_loss
861
860
if self .supervisor is not None :
862
- loss_fn = tf .keras .losses .MeanSquaredError ()
861
+ if self .output_maps [0 ].is_categorical ():
862
+ loss_fn = tf .keras .losses .CategoricalCrossentropy ()
863
+ else :
864
+ loss_fn = tf .keras .losses .MeanSquaredError ()
863
865
supervised_preds = self .supervisor (pred_images , training = True )
864
866
supervised_loss = loss_fn (batch [1 ][self .output_maps [0 ].output_name ()], supervised_preds )
865
867
self .supervised_loss_tracker .update_state (supervised_loss )
866
868
# Combine losses: add noise_loss and supervised_loss
867
- noise_loss += 0.01 * supervised_loss
869
+ noise_loss += self . supervision_scalar * supervised_loss
868
870
869
871
self .image_loss_tracker .update_state (image_loss )
870
872
self .noise_loss_tracker .update_state (noise_loss )
873
+ self .mse_metric .update_state (noises , pred_noises )
874
+ self .mae_metric .update_state (noises , pred_noises )
871
875
872
876
# measure KID between real and generated images
873
877
# this is computationally demanding, kid_diffusion_steps has to be small
874
- images = self .denormalize (images )
875
- generated_images = self .generate (
876
- control_embed , num_images = self .batch_size , diffusion_steps = 20 ,
877
- )
878
- # self.kid.update_state(images, generated_images)
878
+ if self .input_map .axes () == 3 and self .inspect_model :
879
+ images = self .denormalize (images )
880
+ generated_images = self .generate (control_embed ,
881
+ num_images = self .batch_size , diffusion_steps = 20
882
+ )
883
+ self .kid .update_state (images , generated_images )
879
884
880
885
return {m .name : m .result () for m in self .metrics }
881
886
887
+
882
888
def plot_images (self , epoch = None , logs = None , num_rows = 1 , num_cols = 4 , reseed = None , prefix = './figures/' ):
883
889
control_batch = {}
884
890
for cm in self .output_maps :
@@ -912,7 +918,7 @@ def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None
912
918
913
919
def plot_reconstructions (
914
920
self , batch , diffusion_amount = 0 ,
915
- epoch = None , logs = None , num_rows = 4 , num_cols = 4 ,
921
+ epoch = None , logs = None , num_rows = 4 , num_cols = 4 , prefix = './figures/' ,
916
922
):
917
923
images = batch [0 ][self .input_map .input_name ()]
918
924
self .normalizer .update_state (images )
@@ -937,8 +943,28 @@ def plot_reconstructions(
937
943
plt .imshow (generated_images [index ], cmap = 'gray' )
938
944
plt .axis ("off" )
939
945
plt .tight_layout ()
940
- plt .show ()
946
+ now_string = datetime .datetime .now ().strftime ('%Y-%m-%d_%H-%M' )
947
+ figure_path = os .path .join (prefix , f'diffusion_image_reconstructions_{ now_string } { IMAGE_EXT } ' )
948
+ if not os .path .exists (os .path .dirname (figure_path )):
949
+ os .makedirs (os .path .dirname (figure_path ))
950
+ plt .savefig (figure_path , bbox_inches = "tight" )
941
951
plt .close ()
952
+ plt .figure (figsize = (num_cols * 2.0 , num_rows * 2.0 ), dpi = 300 )
953
+ for row in range (num_rows ):
954
+ for col in range (num_cols ):
955
+ index = row * num_cols + col
956
+ plt .subplot (num_rows , num_cols , index + 1 )
957
+ plt .imshow (images [index ], cmap = 'gray' )
958
+ plt .axis ("off" )
959
+ plt .tight_layout ()
960
+ now_string = datetime .datetime .now ().strftime ('%Y-%m-%d_%H-%M' )
961
+ figure_path = os .path .join (prefix , f'input_images_{ now_string } { IMAGE_EXT } ' )
962
+ if not os .path .exists (os .path .dirname (figure_path )):
963
+ os .makedirs (os .path .dirname (figure_path ))
964
+ plt .savefig (figure_path , bbox_inches = "tight" )
965
+ plt .close ()
966
+ return generated_images
967
+
942
968
943
969
def control_plot_images (
944
970
self , control_batch , epoch = None , logs = None , num_rows = 2 , num_cols = 8 , reseed = None ,
0 commit comments