Skip to content

Commit

Permalink
condition and supervise
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Jan 15, 2025
1 parent 6d8d1d4 commit b287451
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions ml4h/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,9 @@ def train_diffusion_control_model(args, supervised=False):
model.load_weights(checkpoint_path)

if args.inspect_model:
metrics = model.evaluate(generate_test, batch_size=args.batch_size, steps=args.test_steps, return_dict=True)
logging.info(f'Test metrics: {metrics}')

data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1)
preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/reconstructions/')
image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]}
Expand Down

0 comments on commit b287451

Please sign in to comment.