Skip to content

Commit 1f00de9

Browse files
author
sana
committed
Merge branch 'sf_attn' of github.com:broadinstitute/ml4h into sf_attn
2 parents ec8c5c8 + 88c59bf commit 1f00de9

File tree

7 files changed

+149
-78
lines changed

7 files changed

+149
-78
lines changed

ml4h/arguments.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,8 @@ def parse_args():
233233
'2 means every other residual block, 3 would mean every third.',
234234
)
235235
parser.add_argument(
236-
'--diffusion_condition_strategy', default='concat', choices=['cross_attention', 'concat', 'film'],
236+
'--diffusion_condition_strategy', default='cross_attention',
237+
choices=['cross_attention', 'concat', 'film'],
237238
help='For diffusion models, this controls conditional embeddings are integrated into the U-NET',
238239
)
239240
parser.add_argument(
@@ -244,6 +245,10 @@ def parse_args():
244245
'--sigmoid_beta', default=-3, type=float,
245246
help='Beta to use with sigmoid loss for diffusion models.',
246247
)
248+
parser.add_argument(
249+
'--supervision_scalar', default=0.01, type=float,
250+
help='For `train_diffusion_supervise` mode, this weights the supervision loss from phenotype prediction on denoised data.',
251+
)
247252
parser.add_argument(
248253
'--transformer_size', default=32, type=int,
249254
help='Number of output neurons in Transformer encoders and decoders, '

ml4h/models/diffusion_blocks.py

Lines changed: 93 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,8 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s
259259

260260
for i, width in enumerate(reversed(widths[:-1])):
261261
if attention_modulo > 1 and ((len(widths) - 1) - i) % attention_modulo == 0:
262-
if len(input_shape) > 2:
263-
c2 = upsample(size=x.shape[1:-1]*2)(control[control_idxs])
262+
if len(input_shape) == 3:
263+
c2 = upsample(size=(x.shape[1]*2, x.shape[2]*2))(control[control_idxs])
264264
else:
265265
c2 = upsample(size=x.shape[-2]*2)(control[control_idxs])
266266
x = up_block_control(width, block_depth, conv, upsample,
@@ -284,7 +284,7 @@ def get_control_embed_model(output_maps, control_size):
284284

285285

286286
class DiffusionModel(keras.Model):
287-
def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, diffusion_loss, sigmoid_beta):
287+
def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, diffusion_loss, sigmoid_beta, inspect_model):
288288
super().__init__()
289289

290290
self.tensor_map = tensor_map
@@ -294,6 +294,7 @@ def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, dif
294294
self.ema_network = keras.models.clone_model(self.network)
295295
self.use_sigmoid_loss = diffusion_loss == 'sigmoid'
296296
self.beta = sigmoid_beta
297+
self.inspect_model = inspect_model
297298

298299
def can_apply(self):
299300
return self.tensor_map.axes() > 1
@@ -303,13 +304,15 @@ def compile(self, **kwargs):
303304

304305
self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
305306
self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
306-
if self.tensor_map.axes() == 3:
307-
self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=75)
307+
self.mse_metric = tf.keras.metrics.MeanSquaredError(name="mse")
308+
self.mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae")
309+
if self.tensor_map.axes() == 3 and self.inspect_model:
310+
self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=299)
308311

309312
@property
310313
def metrics(self):
311-
m = [self.noise_loss_tracker, self.image_loss_tracker]
312-
if self.tensor_map.axes() == 3:
314+
m = [self.noise_loss_tracker, self.image_loss_tracker, self.mse_metric, self.mae_metric]
315+
if self.tensor_map.axes() == 3 and self.inspect_model:
313316
m.append(self.kid)
314317
return m
315318

@@ -428,13 +431,15 @@ def train_step(self, images_original):
428431

429432
self.noise_loss_tracker.update_state(noise_loss)
430433
self.image_loss_tracker.update_state(image_loss)
434+
self.mse_metric.update_state(noises, pred_noises)
435+
self.mae_metric.update_state(noises, pred_noises)
431436

432437
# track the exponential moving averages of weights
433438
for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
434439
ema_weight.assign(ema * ema_weight + (1 - ema) * weight)
435440

436441
# KID is not measured during the training phase for computational efficiency
437-
return {m.name: m.result() for m in self.metrics[:-1]}
442+
return {m.name: m.result() for m in self.metrics}
438443

439444
def test_step(self, images_original):
440445
# normalize images to have standard deviation of 1, like the noises
@@ -470,10 +475,12 @@ def test_step(self, images_original):
470475

471476
self.image_loss_tracker.update_state(image_loss)
472477
self.noise_loss_tracker.update_state(noise_loss)
478+
self.mse_metric.update_state(noises, pred_noises)
479+
self.mae_metric.update_state(noises, pred_noises)
473480

474481
# measure KID between real and generated images
475482
# this is computationally demanding, kid_diffusion_steps has to be small
476-
if self.tensor_map.axes() == 3:
483+
if self.tensor_map.axes() == 3 and self.inspect_model:
477484
images = self.denormalize(images)
478485
generated_images = self.generate(
479486
num_images=self.batch_size, diffusion_steps=20
@@ -534,15 +541,14 @@ def plot_ecgs(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None,
534541
plt.close()
535542

536543
def plot_reconstructions(
537-
self, images_original, diffusion_amount=0,
538-
epoch=None, logs=None, num_rows=3, num_cols=6,
544+
self, images_original, diffusion_amount=0, epoch=None, logs=None, num_rows=2, num_cols=2, prefix='./figures/',
539545
):
540546
images = images_original[0][self.tensor_map.input_name()]
541547
self.normalizer.update_state(images)
542548
images = self.normalizer(images, training=False)
543-
noises = tf.random.normal(shape=(self.batch_size,) + self.tensor_map.shape)
549+
noises = tf.random.normal(shape=(num_rows*num_cols,) + self.tensor_map.shape)
544550

545-
diffusion_times = diffusion_amount * tf.ones(shape=[self.batch_size] + [1] * self.tensor_map.axes())
551+
diffusion_times = diffusion_amount * tf.ones(shape=[num_rows*num_cols] + [1] * self.tensor_map.axes())
546552
noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
547553
# mix the images with noises accordingly
548554
noisy_images = signal_rates * images + noise_rates * noises
@@ -559,8 +565,27 @@ def plot_reconstructions(
559565
plt.imshow(generated_images[index], cmap='gray')
560566
plt.axis("off")
561567
plt.tight_layout()
562-
plt.show()
568+
now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
569+
figure_path = os.path.join(prefix, f'diffusion_reconstructions_{now_string}{IMAGE_EXT}')
570+
if not os.path.exists(os.path.dirname(figure_path)):
571+
os.makedirs(os.path.dirname(figure_path))
572+
plt.savefig(figure_path, bbox_inches="tight")
563573
plt.close()
574+
plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300)
575+
for row in range(num_rows):
576+
for col in range(num_cols):
577+
index = row * num_cols + col
578+
plt.subplot(num_rows, num_cols, index + 1)
579+
plt.imshow(images[index], cmap='gray')
580+
plt.axis("off")
581+
plt.tight_layout()
582+
now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
583+
figure_path = os.path.join(prefix, f'input_images_{now_string}{IMAGE_EXT}')
584+
if not os.path.exists(os.path.dirname(figure_path)):
585+
os.makedirs(os.path.dirname(figure_path))
586+
plt.savefig(figure_path, bbox_inches="tight")
587+
plt.close()
588+
return generated_images
564589

565590
def in_paint(self, images_original, masks, diffusion_steps=64, num_rows=3, num_cols=6):
566591
images = images_original[0][self.tensor_map.input_name()]
@@ -612,7 +637,7 @@ class DiffusionController(keras.Model):
612637
def __init__(
613638
self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size,
614639
attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy,
615-
supervisor = None,
640+
inspect_model, supervisor = None, supervision_scalar = 0.01,
616641
):
617642
super().__init__()
618643

@@ -627,22 +652,28 @@ def __init__(
627652
self.use_sigmoid_loss = diffusion_loss == 'sigmoid'
628653
self.beta = sigmoid_beta
629654
self.supervisor = supervisor
655+
self.supervision_scalar = supervision_scalar
656+
self.inspect_model = inspect_model
630657

631658

632659
def compile(self, **kwargs):
633660
super().compile(**kwargs)
634-
635661
self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
636662
self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
663+
self.mse_metric = tf.keras.metrics.MeanSquaredError(name="mse")
664+
self.mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae")
637665
if self.supervisor is not None:
638666
self.supervised_loss_tracker = keras.metrics.Mean(name="supervised_loss")
639-
# self.kid = KID(name = "kid", input_shape = self.tensor_map.shape)
667+
if self.input_map.axes() == 3 and self.inspect_model:
668+
self.kid = KernelInceptionDistance(name = "kid", input_shape = self.input_map.shape, kernel_image_size=299)
640669

641670
@property
642671
def metrics(self):
643-
m = [self.noise_loss_tracker, self.image_loss_tracker]
672+
m = [self.noise_loss_tracker, self.image_loss_tracker, self.mse_metric, self.mae_metric]
644673
if self.supervisor is not None:
645674
m.append(self.supervised_loss_tracker)
675+
if self.input_map.axes() == 3 and self.inspect_model:
676+
m.append(self.kid)
646677
return m
647678

648679
def denormalize(self, images):
@@ -764,12 +795,15 @@ def train_step(self, batch):
764795
weight = tf.math.sigmoid(self.beta - lambda_t)
765796
noise_loss = weight * noise_loss
766797
if self.supervisor is not None:
767-
loss_fn = tf.keras.losses.MeanSquaredError()
798+
if self.output_maps[0].is_categorical():
799+
loss_fn = tf.keras.losses.CategoricalCrossentropy()
800+
else:
801+
loss_fn = tf.keras.losses.MeanSquaredError()
768802
supervised_preds = self.supervisor(pred_images, training=True)
769803
supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds)
770804
self.supervised_loss_tracker.update_state(supervised_loss)
771805
# Combine losses: add noise_loss and supervised_loss
772-
noise_loss += 0.01 * supervised_loss
806+
noise_loss += self.supervision_scalar * supervised_loss
773807

774808
# Gradients for self.supervised_model
775809
supervised_gradients = tape.gradient(supervised_loss, self.supervisor.trainable_weights)
@@ -780,50 +814,15 @@ def train_step(self, batch):
780814

781815
self.noise_loss_tracker.update_state(noise_loss)
782816
self.image_loss_tracker.update_state(image_loss)
817+
self.mse_metric.update_state(noises, pred_noises)
818+
self.mae_metric.update_state(noises, pred_noises)
783819

784820
# track the exponential moving averages of weights
785821
for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
786822
ema_weight.assign(ema * ema_weight + (1 - ema) * weight)
787823

788824
# KID is not measured during the training phase for computational efficiency
789-
return {m.name: m.result() for m in self.metrics[:-1]}
790-
791-
# def call(self, inputs):
792-
# # normalize images to have standard deviation of 1, like the noises
793-
# images = inputs[self.input_map.input_name()]
794-
# self.normalizer.update_state(images)
795-
# images = self.normalizer(images, training=False)
796-
797-
# control_embed = self.control_embed_model(inputs)
798-
799-
# noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape)
800-
801-
# # sample uniform random diffusion times
802-
# diffusion_times = tf.random.uniform(
803-
# shape=[self.batch_size, ] + [1] * self.input_map.axes(), minval=0.0, maxval=1.0
804-
# )
805-
# noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
806-
# # mix the images with noises accordingly
807-
# noisy_images = signal_rates * images + noise_rates * noises
808-
809-
# # use the network to separate noisy images to their components
810-
# pred_noises, pred_images = self.denoise(
811-
# control_embed, noisy_images, noise_rates, signal_rates, training=False
812-
# )
813-
814-
# noise_loss = self.loss(noises, pred_noises)
815-
# image_loss = self.loss(images, pred_images)
816-
817-
# self.image_loss_tracker.update_state(image_loss)
818-
# self.noise_loss_tracker.update_state(noise_loss)
819-
820-
# # measure KID between real and generated images
821-
# # this is computationally demanding, kid_diffusion_steps has to be small
822-
# images = self.denormalize(images)
823-
# generated_images = self.generate(
824-
# control_embed, num_images=self.batch_size, diffusion_steps=20
825-
# )
826-
# return generated_images
825+
return {m.name: m.result() for m in self.metrics}
827826

828827
def test_step(self, batch):
829828
# normalize images to have standard deviation of 1, like the noises
@@ -859,26 +858,33 @@ def test_step(self, batch):
859858
weight = tf.math.sigmoid(self.beta - lambda_t)
860859
noise_loss = weight * noise_loss
861860
if self.supervisor is not None:
862-
loss_fn = tf.keras.losses.MeanSquaredError()
861+
if self.output_maps[0].is_categorical():
862+
loss_fn = tf.keras.losses.CategoricalCrossentropy()
863+
else:
864+
loss_fn = tf.keras.losses.MeanSquaredError()
863865
supervised_preds = self.supervisor(pred_images, training=True)
864866
supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds)
865867
self.supervised_loss_tracker.update_state(supervised_loss)
866868
# Combine losses: add noise_loss and supervised_loss
867-
noise_loss += 0.01*supervised_loss
869+
noise_loss += self.supervision_scalar*supervised_loss
868870

869871
self.image_loss_tracker.update_state(image_loss)
870872
self.noise_loss_tracker.update_state(noise_loss)
873+
self.mse_metric.update_state(noises, pred_noises)
874+
self.mae_metric.update_state(noises, pred_noises)
871875

872876
# measure KID between real and generated images
873877
# this is computationally demanding, kid_diffusion_steps has to be small
874-
images = self.denormalize(images)
875-
generated_images = self.generate(
876-
control_embed, num_images=self.batch_size, diffusion_steps=20,
877-
)
878-
# self.kid.update_state(images, generated_images)
878+
if self.input_map.axes() == 3 and self.inspect_model:
879+
images = self.denormalize(images)
880+
generated_images = self.generate(control_embed,
881+
num_images=self.batch_size, diffusion_steps=20
882+
)
883+
self.kid.update_state(images, generated_images)
879884

880885
return {m.name: m.result() for m in self.metrics}
881886

887+
882888
def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, prefix='./figures/'):
883889
control_batch = {}
884890
for cm in self.output_maps:
@@ -912,7 +918,7 @@ def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None
912918

913919
def plot_reconstructions(
914920
self, batch, diffusion_amount=0,
915-
epoch=None, logs=None, num_rows=4, num_cols=4,
921+
epoch=None, logs=None, num_rows=4, num_cols=4, prefix='./figures/',
916922
):
917923
images = batch[0][self.input_map.input_name()]
918924
self.normalizer.update_state(images)
@@ -937,8 +943,28 @@ def plot_reconstructions(
937943
plt.imshow(generated_images[index], cmap='gray')
938944
plt.axis("off")
939945
plt.tight_layout()
940-
plt.show()
946+
now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
947+
figure_path = os.path.join(prefix, f'diffusion_image_reconstructions_{now_string}{IMAGE_EXT}')
948+
if not os.path.exists(os.path.dirname(figure_path)):
949+
os.makedirs(os.path.dirname(figure_path))
950+
plt.savefig(figure_path, bbox_inches="tight")
941951
plt.close()
952+
plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300)
953+
for row in range(num_rows):
954+
for col in range(num_cols):
955+
index = row * num_cols + col
956+
plt.subplot(num_rows, num_cols, index + 1)
957+
plt.imshow(images[index], cmap='gray')
958+
plt.axis("off")
959+
plt.tight_layout()
960+
now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
961+
figure_path = os.path.join(prefix, f'input_images_{now_string}{IMAGE_EXT}')
962+
if not os.path.exists(os.path.dirname(figure_path)):
963+
os.makedirs(os.path.dirname(figure_path))
964+
plt.savefig(figure_path, bbox_inches="tight")
965+
plt.close()
966+
return generated_images
967+
942968

943969
def control_plot_images(
944970
self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None,

0 commit comments

Comments
 (0)