Skip to content

Commit 5e67b6f

Browse files
committed
condition and supervise
1 parent 2a3c520 commit 5e67b6f

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

ml4h/models/train.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,17 @@ def train_diffusion_model(args):
178178
metrics = model.evaluate(generate_test, batch_size=args.batch_size, steps=args.test_steps, return_dict=True)
179179
logging.info(f'Test metrics: {metrics}')
180180

181-
data, labels, paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
182-
preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/reconstructions/')
181+
steps = 1 if args.batch_size > 3 else args.test_steps
182+
data, labels, paths = big_batch_from_minibatch_generator(generate_test, steps)
183+
sides = int(args.batch_size)
184+
preds = model.plot_reconstructions((data, labels), num_rows=sides, num_cols=sides,
185+
prefix=f'{args.output_folder}/{args.id}/reconstructions/')
183186
image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]}
184187
predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/reconstructions/')
185188
if model.tensor_map.axes() == 2:
186189
model.plot_ecgs(num_rows=4, prefix=os.path.dirname(checkpoint_path))
187190
else:
188-
model.plot_images(num_cols=min(4,args.batch_size), num_rows=min(4,args.test_steps), prefix=os.path.dirname(checkpoint_path))
191+
model.plot_images(num_cols=sides, num_rows=sides, prefix=os.path.dirname(checkpoint_path))
189192
return model
190193

191194

0 commit comments

Comments
 (0)