Skip to content

Commit

Permalink
sigmoid loss unconditioned
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Dec 19, 2024
1 parent baf60a3 commit 2df321c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
10 changes: 8 additions & 2 deletions ml4h/models/diffusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,10 @@ def reverse_diffusion(self, initial_noise, diffusion_steps):

return pred_images

def generate(self, num_images, diffusion_steps):
def generate(self, num_images, diffusion_steps, reseed=None):
# noise -> images -> denormalized images
if reseed is not None:
tf.random.set_seed(reseed)
initial_noise = tf.random.normal(shape=(num_images,) + self.tensor_map.shape)
generated_images = self.reverse_diffusion(initial_noise, diffusion_steps)
generated_images = self.denormalize(generated_images)
Expand Down Expand Up @@ -474,11 +476,12 @@ def test_step(self, images_original):

return {m.name: m.result() for m in self.metrics}

def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, prefix='./figures/'):
def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, reseed=None, prefix='./figures/'):
# plot random generated images for visual evaluation of generation quality
generated_images = self.generate(
num_images=num_rows * num_cols,
diffusion_steps=plot_diffusion_steps,
reseed=reseed,
)

plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300)
Expand All @@ -497,12 +500,14 @@ def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, prefix='./f
if not os.path.exists(os.path.dirname(figure_path)):
os.makedirs(os.path.dirname(figure_path))
plt.savefig(figure_path, bbox_inches="tight")
plt.close()

def plot_ecgs(self, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, prefix='./figures/'):
    # plot random generated ECGs for visual evaluation of generation quality
generated_images = self.generate(
num_images=max(self.batch_size, num_rows * num_cols),
diffusion_steps=plot_diffusion_steps,
reseed=reseed,
)

plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300)
Expand All @@ -518,6 +523,7 @@ def plot_ecgs(self, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None,
if not os.path.exists(os.path.dirname(figure_path)):
os.makedirs(os.path.dirname(figure_path))
plt.savefig(figure_path, bbox_inches="tight")
plt.close()

def plot_reconstructions(
self, images_original, diffusion_amount=0,
Expand Down
22 changes: 15 additions & 7 deletions ml4h/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ def train_diffusion_model(args):
for k in batch[1]:
logging.info(f"label {k} {batch[1][k].shape}")
checkpoint_path = f"{args.output_folder}{args.id}/{args.id}"
if os.path.exists(checkpoint_path+'.index'):
model.load_weights(checkpoint_path)
logging.info(f'Loaded weights from model checkpoint at: {checkpoint_path}')
else:
logging.info(f'No checkpoint at: {checkpoint_path}')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True,
Expand All @@ -132,8 +137,11 @@ def train_diffusion_model(args):
save_best_only=True,
)

callbacks = [checkpoint_callback]

# calculate mean and variance of training dataset for normalization
model.normalizer.adapt(feature_batch)

if args.inspect_model:
model.network.summary(print_fn=logging.info, expand_nested=True)
tf.keras.utils.plot_model(
Expand All @@ -148,20 +156,20 @@ def train_diffusion_model(args):
layer_range=None,
show_layer_activations=False,
)

if os.path.exists(checkpoint_path+'.index'):
model.load_weights(checkpoint_path)
logging.info(f'Loaded weights from model checkpoint at: {checkpoint_path}')
else:
logging.info(f'No checkpoint at: {checkpoint_path}')
prefix_value = f'{args.output_folder}{args.id}/learning_generations/'
if model.tensor_map.axes() == 2:
plot_partial = partial(model.plot_ecgs, reseed=args.random_seed, prefix=prefix_value)
else:
plot_partial = partial(model.plot_images, reseed=args.random_seed, prefix=prefix_value)
callbacks.append(keras.callbacks.LambdaCallback(on_epoch_end=plot_partial))

history = model.fit(
generate_train,
steps_per_epoch=args.training_steps,
epochs=args.epochs,
validation_data=generate_valid,
validation_steps=args.validation_steps,
callbacks=[checkpoint_callback],
callbacks=callbacks,
)
model.load_weights(checkpoint_path)
#diffusion_model.compile(optimizer='adam', loss='mse')
Expand Down

0 comments on commit 2df321c

Please sign in to comment.