diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index a7bf28fe2..780246e32 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -941,6 +941,7 @@ def plot_reconstructions( os.makedirs(os.path.dirname(figure_path)) plt.savefig(figure_path, bbox_inches="tight") plt.close() + return generated_images def control_plot_images( diff --git a/ml4h/models/train.py b/ml4h/models/train.py index e4c03f752..b6f910a95 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -375,7 +375,10 @@ def train_diffusion_control_model(args, supervised=False): if args.inspect_model: data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) - model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/') + preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/') + #images = data[args.tensor_maps_in[0].input_name()] + image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]} + predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/') interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, f'{args.output_folder}/{args.id}/') if model.input_map.axes() == 2: