From 0a32e0bd65a33b9df77002382a0ea3f2f44edfd1 Mon Sep 17 00:00:00 2001 From: Sam Freesun Friedman Date: Mon, 13 Jan 2025 12:22:17 -0500 Subject: [PATCH] condition and supervise --- ml4h/models/diffusion_blocks.py | 17 ++++++++++++++++- ml4h/models/train.py | 2 -- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index cc4f35ff2..a7bf28fe2 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -922,12 +922,27 @@ def plot_reconstructions( plt.axis("off") plt.tight_layout() now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') - figure_path = os.path.join(prefix, f'diffusion_image_generations_{now_string}{IMAGE_EXT}') + figure_path = os.path.join(prefix, f'diffusion_image_reconstructions_{now_string}{IMAGE_EXT}') + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path, bbox_inches="tight") + plt.close() + plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300) + for row in range(num_rows): + for col in range(num_cols): + index = row * num_cols + col + plt.subplot(num_rows, num_cols, index + 1) + plt.imshow(images[index], cmap='gray') + plt.axis("off") + plt.tight_layout() + now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') + figure_path = os.path.join(prefix, f'input_images_{now_string}{IMAGE_EXT}') if not os.path.exists(os.path.dirname(figure_path)): os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") plt.close() + def control_plot_images( self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None, renoise=None, diff --git a/ml4h/models/train.py b/ml4h/models/train.py index f5d041396..e4c03f752 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -376,8 +376,6 @@ def train_diffusion_control_model(args, supervised=False): data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/') - images = data[args.tensor_maps_in[0].input_name()] - predictions_to_pngs(images, args.tensor_maps_in, args.tensor_maps_in, data, labels, paths, '{args.output_folder}/{args.id}/') interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, f'{args.output_folder}/{args.id}/') if model.input_map.axes() == 2: