From e3a179cc9327a6edf9aafa983a6805992ed8d040 Mon Sep 17 00:00:00 2001
From: Sam Freesun Friedman
Date: Fri, 17 Jan 2025 13:36:35 -0500
Subject: [PATCH] condition and supervise

---
 ml4h/models/diffusion_blocks.py | 20 +++++++++++++-------
 ml4h/models/train.py            | 14 +++++++++-----
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py
index 431060eb0..4d7ac2cbd 100644
--- a/ml4h/models/diffusion_blocks.py
+++ b/ml4h/models/diffusion_blocks.py
@@ -637,7 +637,7 @@ class DiffusionController(keras.Model):
     def __init__(
         self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size,
         attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy,
-        supervisor = None, supervision_scalar = 0.01,
+        inspect_model, supervisor = None, supervision_scalar = 0.01,
     ):
         super().__init__()

@@ -653,6 +653,7 @@ def __init__(
         self.beta = sigmoid_beta
         self.supervisor = supervisor
         self.supervision_scalar = supervision_scalar
+        self.inspect_model = inspect_model

     def compile(self, **kwargs):
@@ -663,13 +664,16 @@ def compile(self, **kwargs):
         self.mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae")
         if self.supervisor is not None:
             self.supervised_loss_tracker = keras.metrics.Mean(name="supervised_loss")
-        # self.kid = KID(name = "kid", input_shape = self.tensor_map.shape)
+        if self.input_map.axes() == 3 and self.inspect_model:
+            self.kid = KernelInceptionDistance(name = "kid", input_shape = self.input_map.shape, kernel_image_size=299)

     @property
     def metrics(self):
         m = [self.noise_loss_tracker, self.image_loss_tracker, self.mse_metric, self.mae_metric]
         if self.supervisor is not None:
             m.append(self.supervised_loss_tracker)
+        if self.input_map.axes() == 3 and self.inspect_model:
+            m.append(self.kid)
         return m

     def denormalize(self, images):
@@ -871,14 +875,16 @@ def test_step(self, batch):

         # measure KID between real and generated images
         # this is computationally demanding, kid_diffusion_steps has to be small
-        images = self.denormalize(images)
-        generated_images = self.generate(
-            control_embed, num_images=self.batch_size, diffusion_steps=20,
-        )
-        # self.kid.update_state(images, generated_images)
+        if self.input_map.axes() == 3 and self.inspect_model:
+            images = self.denormalize(images)
+            generated_images = self.generate(
+                control_embed, num_images=self.batch_size, diffusion_steps=20,
+            )
+            self.kid.update_state(images, generated_images)

         return {m.name: m.result() for m in self.metrics}

+
     def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, prefix='./figures/'):
         control_batch = {}
         for cm in self.output_maps:
diff --git a/ml4h/models/train.py b/ml4h/models/train.py
index 734ae02c0..eb935f9aa 100755
--- a/ml4h/models/train.py
+++ b/ml4h/models/train.py
@@ -296,13 +296,13 @@ def train_diffusion_control_model(args, supervised=False):
         model = DiffusionController(
             args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x, args.dense_layers[0],
             args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss,
-            args.sigmoid_beta, args.diffusion_condition_strategy, supervised_model, args.supervision_scalar,
+            args.sigmoid_beta, args.diffusion_condition_strategy, args.inspect_model, supervised_model, args.supervision_scalar,
         )
     else:
         model = DiffusionController(
             args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x, args.dense_layers[0],
             args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss,
-            args.sigmoid_beta, args.diffusion_condition_strategy,
+            args.sigmoid_beta, args.diffusion_condition_strategy, args.inspect_model,
         )

     loss = keras.losses.mean_absolute_error if args.diffusion_loss == 'mean_absolute_error' else keras.losses.mean_squared_error
@@ -385,8 +385,12 @@ def train_diffusion_control_model(args, supervised=False):

     metrics = model.evaluate(generate_test, batch_size=args.batch_size, steps=args.test_steps, return_dict=True)
     logging.info(f'Test metrics: {metrics}')
-    data, labels, paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
-    preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/reconstructions/')
+    steps = 1 if args.batch_size > 3 else args.test_steps
+    data, labels, paths = big_batch_from_minibatch_generator(generate_test, steps)
+    sides = int(np.sqrt(steps*args.batch_size))
+    preds = model.plot_reconstructions((data, labels), num_rows=sides, num_cols=sides,
+                                       prefix=f'{args.output_folder}/{args.id}/reconstructions/')
+
     image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]}
     predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/reconstructions/')
     interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size,
@@ -394,7 +398,7 @@ def train_diffusion_control_model(args, supervised=False):
     if model.input_map.axes() == 2:
         model.plot_ecgs(num_rows=2, prefix=os.path.dirname(checkpoint_path))
     else:
-        model.plot_images(num_rows=2, prefix=os.path.dirname(checkpoint_path))
+        model.plot_images(num_cols=sides, num_rows=sides, prefix=os.path.dirname(checkpoint_path))

     for tm_out, model_file in zip(args.tensor_maps_out, args.model_files):
         args.tensor_maps_out = [tm_out]