Skip to content

Commit

Permalink
condition and supervise
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Jan 13, 2025
1 parent 0a32e0b commit 4f6f062
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions ml4h/models/diffusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,7 @@ def plot_reconstructions(
os.makedirs(os.path.dirname(figure_path))
plt.savefig(figure_path, bbox_inches="tight")
plt.close()
return generated_images


def control_plot_images(
Expand Down
5 changes: 4 additions & 1 deletion ml4h/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,10 @@ def train_diffusion_control_model(args, supervised=False):
if args.inspect_model:
data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1)

model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/')
preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/')
#images = data[args.tensor_maps_in[0].input_name()]
image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]}
predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/')
interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size,
f'{args.output_folder}/{args.id}/')
if model.input_map.axes() == 2:
Expand Down

0 comments on commit 4f6f062

Please sign in to comment.