Skip to content

Commit

Permalink
Merge branch 'sf_attn' of github.com:broadinstitute/ml4h into sf_attn
Browse files Browse the repository at this point in the history
  • Loading branch information
sana committed Jan 17, 2025
2 parents ec8c5c8 + 88c59bf commit 1f00de9
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 78 deletions.
7 changes: 6 additions & 1 deletion ml4h/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def parse_args():
'2 means every other residual block, 3 would mean every third.',
)
parser.add_argument(
'--diffusion_condition_strategy', default='concat', choices=['cross_attention', 'concat', 'film'],
'--diffusion_condition_strategy', default='cross_attention',
choices=['cross_attention', 'concat', 'film'],
help='For diffusion models, this controls conditional embeddings are integrated into the U-NET',
)
parser.add_argument(
Expand All @@ -244,6 +245,10 @@ def parse_args():
'--sigmoid_beta', default=-3, type=float,
help='Beta to use with sigmoid loss for diffusion models.',
)
parser.add_argument(
'--supervision_scalar', default=0.01, type=float,
help='For `train_diffusion_supervise` mode, this weights the supervision loss from phenotype prediction on denoised data.',
)
parser.add_argument(
'--transformer_size', default=32, type=int,
help='Number of output neurons in Transformer encoders and decoders, '
Expand Down
160 changes: 93 additions & 67 deletions ml4h/models/diffusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ def get_control_network(input_shape, widths, block_depth, kernel_size, control_s

for i, width in enumerate(reversed(widths[:-1])):
if attention_modulo > 1 and ((len(widths) - 1) - i) % attention_modulo == 0:
if len(input_shape) > 2:
c2 = upsample(size=x.shape[1:-1]*2)(control[control_idxs])
if len(input_shape) == 3:
c2 = upsample(size=(x.shape[1]*2, x.shape[2]*2))(control[control_idxs])
else:
c2 = upsample(size=x.shape[-2]*2)(control[control_idxs])
x = up_block_control(width, block_depth, conv, upsample,
Expand All @@ -284,7 +284,7 @@ def get_control_embed_model(output_maps, control_size):


class DiffusionModel(keras.Model):
def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, diffusion_loss, sigmoid_beta):
def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, diffusion_loss, sigmoid_beta, inspect_model):
super().__init__()

self.tensor_map = tensor_map
Expand All @@ -294,6 +294,7 @@ def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, dif
self.ema_network = keras.models.clone_model(self.network)
self.use_sigmoid_loss = diffusion_loss == 'sigmoid'
self.beta = sigmoid_beta
self.inspect_model = inspect_model

def can_apply(self):
return self.tensor_map.axes() > 1
Expand All @@ -303,13 +304,15 @@ def compile(self, **kwargs):

self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
if self.tensor_map.axes() == 3:
self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=75)
self.mse_metric = tf.keras.metrics.MeanSquaredError(name="mse")
self.mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae")
if self.tensor_map.axes() == 3 and self.inspect_model:
self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=299)

@property
def metrics(self):
m = [self.noise_loss_tracker, self.image_loss_tracker]
if self.tensor_map.axes() == 3:
m = [self.noise_loss_tracker, self.image_loss_tracker, self.mse_metric, self.mae_metric]
if self.tensor_map.axes() == 3 and self.inspect_model:
m.append(self.kid)
return m

Expand Down Expand Up @@ -428,13 +431,15 @@ def train_step(self, images_original):

self.noise_loss_tracker.update_state(noise_loss)
self.image_loss_tracker.update_state(image_loss)
self.mse_metric.update_state(noises, pred_noises)
self.mae_metric.update_state(noises, pred_noises)

# track the exponential moving averages of weights
for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

# KID is only computed in test_step, not during training, for computational efficiency
return {m.name: m.result() for m in self.metrics[:-1]}
return {m.name: m.result() for m in self.metrics}

def test_step(self, images_original):
# normalize images to have standard deviation of 1, like the noises
Expand Down Expand Up @@ -470,10 +475,12 @@ def test_step(self, images_original):

self.image_loss_tracker.update_state(image_loss)
self.noise_loss_tracker.update_state(noise_loss)
self.mse_metric.update_state(noises, pred_noises)
self.mae_metric.update_state(noises, pred_noises)

# measure KID between real and generated images
# this is computationally demanding, so only a small number of diffusion steps is used to generate the images
if self.tensor_map.axes() == 3:
if self.tensor_map.axes() == 3 and self.inspect_model:
images = self.denormalize(images)
generated_images = self.generate(
num_images=self.batch_size, diffusion_steps=20
Expand Down Expand Up @@ -534,15 +541,14 @@ def plot_ecgs(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None,
plt.close()

def plot_reconstructions(
self, images_original, diffusion_amount=0,
epoch=None, logs=None, num_rows=3, num_cols=6,
self, images_original, diffusion_amount=0, epoch=None, logs=None, num_rows=2, num_cols=2, prefix='./figures/',
):
images = images_original[0][self.tensor_map.input_name()]
self.normalizer.update_state(images)
images = self.normalizer(images, training=False)
noises = tf.random.normal(shape=(self.batch_size,) + self.tensor_map.shape)
noises = tf.random.normal(shape=(num_rows*num_cols,) + self.tensor_map.shape)

diffusion_times = diffusion_amount * tf.ones(shape=[self.batch_size] + [1] * self.tensor_map.axes())
diffusion_times = diffusion_amount * tf.ones(shape=[num_rows*num_cols] + [1] * self.tensor_map.axes())
noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
# mix the images with noises accordingly
noisy_images = signal_rates * images + noise_rates * noises
Expand All @@ -559,8 +565,27 @@ def plot_reconstructions(
plt.imshow(generated_images[index], cmap='gray')
plt.axis("off")
plt.tight_layout()
plt.show()
now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
figure_path = os.path.join(prefix, f'diffusion_reconstructions_{now_string}{IMAGE_EXT}')
if not os.path.exists(os.path.dirname(figure_path)):
os.makedirs(os.path.dirname(figure_path))
plt.savefig(figure_path, bbox_inches="tight")
plt.close()
plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300)
for row in range(num_rows):
for col in range(num_cols):
index = row * num_cols + col
plt.subplot(num_rows, num_cols, index + 1)
plt.imshow(images[index], cmap='gray')
plt.axis("off")
plt.tight_layout()
now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
figure_path = os.path.join(prefix, f'input_images_{now_string}{IMAGE_EXT}')
if not os.path.exists(os.path.dirname(figure_path)):
os.makedirs(os.path.dirname(figure_path))
plt.savefig(figure_path, bbox_inches="tight")
plt.close()
return generated_images

def in_paint(self, images_original, masks, diffusion_steps=64, num_rows=3, num_cols=6):
images = images_original[0][self.tensor_map.input_name()]
Expand Down Expand Up @@ -612,7 +637,7 @@ class DiffusionController(keras.Model):
def __init__(
self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size,
attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy,
supervisor = None,
inspect_model, supervisor = None, supervision_scalar = 0.01,
):
super().__init__()

Expand All @@ -627,22 +652,28 @@ def __init__(
self.use_sigmoid_loss = diffusion_loss == 'sigmoid'
self.beta = sigmoid_beta
self.supervisor = supervisor
self.supervision_scalar = supervision_scalar
self.inspect_model = inspect_model


def compile(self, **kwargs):
super().compile(**kwargs)

self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
self.mse_metric = tf.keras.metrics.MeanSquaredError(name="mse")
self.mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae")
if self.supervisor is not None:
self.supervised_loss_tracker = keras.metrics.Mean(name="supervised_loss")
# self.kid = KID(name = "kid", input_shape = self.tensor_map.shape)
if self.input_map.axes() == 3 and self.inspect_model:
self.kid = KernelInceptionDistance(name = "kid", input_shape = self.input_map.shape, kernel_image_size=299)

@property
def metrics(self):
m = [self.noise_loss_tracker, self.image_loss_tracker]
m = [self.noise_loss_tracker, self.image_loss_tracker, self.mse_metric, self.mae_metric]
if self.supervisor is not None:
m.append(self.supervised_loss_tracker)
if self.input_map.axes() == 3 and self.inspect_model:
m.append(self.kid)
return m

def denormalize(self, images):
Expand Down Expand Up @@ -764,12 +795,15 @@ def train_step(self, batch):
weight = tf.math.sigmoid(self.beta - lambda_t)
noise_loss = weight * noise_loss
if self.supervisor is not None:
loss_fn = tf.keras.losses.MeanSquaredError()
if self.output_maps[0].is_categorical():
loss_fn = tf.keras.losses.CategoricalCrossentropy()
else:
loss_fn = tf.keras.losses.MeanSquaredError()
supervised_preds = self.supervisor(pred_images, training=True)
supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds)
self.supervised_loss_tracker.update_state(supervised_loss)
# Combine losses: add noise_loss and supervised_loss
noise_loss += 0.01 * supervised_loss
noise_loss += self.supervision_scalar * supervised_loss

# Gradients for self.supervised_model
supervised_gradients = tape.gradient(supervised_loss, self.supervisor.trainable_weights)
Expand All @@ -780,50 +814,15 @@ def train_step(self, batch):

self.noise_loss_tracker.update_state(noise_loss)
self.image_loss_tracker.update_state(image_loss)
self.mse_metric.update_state(noises, pred_noises)
self.mae_metric.update_state(noises, pred_noises)

# track the exponential moving averages of weights
for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

# KID is only computed in test_step, not during training, for computational efficiency
return {m.name: m.result() for m in self.metrics[:-1]}

# def call(self, inputs):
# # normalize images to have standard deviation of 1, like the noises
# images = inputs[self.input_map.input_name()]
# self.normalizer.update_state(images)
# images = self.normalizer(images, training=False)

# control_embed = self.control_embed_model(inputs)

# noises = tf.random.normal(shape=(self.batch_size,) + self.input_map.shape)

# # sample uniform random diffusion times
# diffusion_times = tf.random.uniform(
# shape=[self.batch_size, ] + [1] * self.input_map.axes(), minval=0.0, maxval=1.0
# )
# noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
# # mix the images with noises accordingly
# noisy_images = signal_rates * images + noise_rates * noises

# # use the network to separate noisy images to their components
# pred_noises, pred_images = self.denoise(
# control_embed, noisy_images, noise_rates, signal_rates, training=False
# )

# noise_loss = self.loss(noises, pred_noises)
# image_loss = self.loss(images, pred_images)

# self.image_loss_tracker.update_state(image_loss)
# self.noise_loss_tracker.update_state(noise_loss)

# # measure KID between real and generated images
# # this is computationally demanding, kid_diffusion_steps has to be small
# images = self.denormalize(images)
# generated_images = self.generate(
# control_embed, num_images=self.batch_size, diffusion_steps=20
# )
# return generated_images
return {m.name: m.result() for m in self.metrics}

def test_step(self, batch):
# normalize images to have standard deviation of 1, like the noises
Expand Down Expand Up @@ -859,26 +858,33 @@ def test_step(self, batch):
weight = tf.math.sigmoid(self.beta - lambda_t)
noise_loss = weight * noise_loss
if self.supervisor is not None:
loss_fn = tf.keras.losses.MeanSquaredError()
if self.output_maps[0].is_categorical():
loss_fn = tf.keras.losses.CategoricalCrossentropy()
else:
loss_fn = tf.keras.losses.MeanSquaredError()
supervised_preds = self.supervisor(pred_images, training=True)
supervised_loss = loss_fn(batch[1][self.output_maps[0].output_name()], supervised_preds)
self.supervised_loss_tracker.update_state(supervised_loss)
# Combine losses: add noise_loss and supervised_loss
noise_loss += 0.01*supervised_loss
noise_loss += self.supervision_scalar*supervised_loss

self.image_loss_tracker.update_state(image_loss)
self.noise_loss_tracker.update_state(noise_loss)
self.mse_metric.update_state(noises, pred_noises)
self.mae_metric.update_state(noises, pred_noises)

# measure KID between real and generated images
# this is computationally demanding, so only a small number of diffusion steps is used to generate the images
images = self.denormalize(images)
generated_images = self.generate(
control_embed, num_images=self.batch_size, diffusion_steps=20,
)
# self.kid.update_state(images, generated_images)
if self.input_map.axes() == 3 and self.inspect_model:
images = self.denormalize(images)
generated_images = self.generate(control_embed,
num_images=self.batch_size, diffusion_steps=20
)
self.kid.update_state(images, generated_images)

return {m.name: m.result() for m in self.metrics}


def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, prefix='./figures/'):
control_batch = {}
for cm in self.output_maps:
Expand Down Expand Up @@ -912,7 +918,7 @@ def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None

def plot_reconstructions(
self, batch, diffusion_amount=0,
epoch=None, logs=None, num_rows=4, num_cols=4,
epoch=None, logs=None, num_rows=4, num_cols=4, prefix='./figures/',
):
images = batch[0][self.input_map.input_name()]
self.normalizer.update_state(images)
Expand All @@ -937,8 +943,28 @@ def plot_reconstructions(
plt.imshow(generated_images[index], cmap='gray')
plt.axis("off")
plt.tight_layout()
plt.show()
now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
figure_path = os.path.join(prefix, f'diffusion_image_reconstructions_{now_string}{IMAGE_EXT}')
if not os.path.exists(os.path.dirname(figure_path)):
os.makedirs(os.path.dirname(figure_path))
plt.savefig(figure_path, bbox_inches="tight")
plt.close()
plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300)
for row in range(num_rows):
for col in range(num_cols):
index = row * num_cols + col
plt.subplot(num_rows, num_cols, index + 1)
plt.imshow(images[index], cmap='gray')
plt.axis("off")
plt.tight_layout()
now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
figure_path = os.path.join(prefix, f'input_images_{now_string}{IMAGE_EXT}')
if not os.path.exists(os.path.dirname(figure_path)):
os.makedirs(os.path.dirname(figure_path))
plt.savefig(figure_path, bbox_inches="tight")
plt.close()
return generated_images


def control_plot_images(
self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None,
Expand Down
Loading

0 comments on commit 1f00de9

Please sign in to comment.