Skip to content

Commit

Permalink
condition and supervise
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Jan 14, 2025
1 parent fc30746 commit efa7527
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
9 changes: 5 additions & 4 deletions ml4h/models/diffusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def get_control_embed_model(output_maps, control_size):


class DiffusionModel(keras.Model):
def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, diffusion_loss, sigmoid_beta):
def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, diffusion_loss, sigmoid_beta, inspect_model):
super().__init__()

self.tensor_map = tensor_map
Expand All @@ -294,6 +294,7 @@ def __init__(self, tensor_map, batch_size, widths, block_depth, kernel_size, dif
self.ema_network = keras.models.clone_model(self.network)
self.use_sigmoid_loss = diffusion_loss == 'sigmoid'
self.beta = sigmoid_beta
self.inspect_model = inspect_model

def can_apply(self):
return self.tensor_map.axes() > 1
Expand All @@ -305,13 +306,13 @@ def compile(self, **kwargs):
self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
self.mse_metric = tf.keras.metrics.MeanSquaredError(name="mse")
self.mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae")
if self.tensor_map.axes() == 3:
if self.tensor_map.axes() == 3 and self.inspect_model:
self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=75)

@property
def metrics(self):
m = [self.noise_loss_tracker, self.image_loss_tracker, self.mse_metric, self.mae_metric]
if self.tensor_map.axes() == 3:
if self.tensor_map.axes() == 3 and self.inspect_model:
m.append(self.kid)
return m

Expand Down Expand Up @@ -479,7 +480,7 @@ def test_step(self, images_original):

# measure KID between real and generated images
# this is computationally demanding, kid_diffusion_steps has to be small
if self.tensor_map.axes() == 3:
if self.tensor_map.axes() == 3 and self.inspect_model:
images = self.denormalize(images)
generated_images = self.generate(
num_images=self.batch_size, diffusion_steps=20
Expand Down
10 changes: 8 additions & 2 deletions ml4h/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _get_callbacks(
def train_diffusion_model(args):
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)
model = DiffusionModel(args.tensor_maps_in[0], args.batch_size, args.dense_blocks, args.block_size, args.conv_x,
args.diffusion_loss, args.sigmoid_beta)
args.diffusion_loss, args.sigmoid_beta, args.inspect_model)

model.compile(
optimizer=tfa.optimizers.AdamW(
Expand Down Expand Up @@ -173,9 +173,15 @@ def train_diffusion_model(args):
callbacks=callbacks,
)
model.load_weights(checkpoint_path)
#diffusion_model.compile(optimizer='adam', loss='mse')
plot_metric_history(history, args.training_steps, args.id, os.path.dirname(checkpoint_path))
if args.inspect_model:
metrics = model.evaluate(generate_test, batch_size=args.batch_size, steps=args.test_steps, return_dict=True)
logging.info(f'Test metrics: {metrics}')

data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1)
preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/reconstructions/')
image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]}
predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/reconstructions/')
if model.tensor_map.axes() == 2:
model.plot_ecgs(num_rows=4, prefix=os.path.dirname(checkpoint_path))
else:
Expand Down

0 comments on commit efa7527

Please sign in to comment.