Skip to content

Commit

Permalink
condition and supervise
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Jan 16, 2025
1 parent 2a3c520 commit 5e67b6f
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions ml4h/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,17 @@ def train_diffusion_model(args):
metrics = model.evaluate(generate_test, batch_size=args.batch_size, steps=args.test_steps, return_dict=True)
logging.info(f'Test metrics: {metrics}')

data, labels, paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/reconstructions/')
steps = 1 if args.batch_size > 3 else args.test_steps
data, labels, paths = big_batch_from_minibatch_generator(generate_test, steps)
sides = int(args.batch_size)
preds = model.plot_reconstructions((data, labels), num_rows=sides, num_cols=sides,
prefix=f'{args.output_folder}/{args.id}/reconstructions/')
image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]}
predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/reconstructions/')
if model.tensor_map.axes() == 2:
model.plot_ecgs(num_rows=4, prefix=os.path.dirname(checkpoint_path))
else:
model.plot_images(num_cols=min(4,args.batch_size), num_rows=min(4,args.test_steps), prefix=os.path.dirname(checkpoint_path))
model.plot_images(num_cols=sides, num_rows=sides, prefix=os.path.dirname(checkpoint_path))
return model


Expand Down

0 comments on commit 5e67b6f

Please sign in to comment.