@@ -178,14 +178,17 @@ def train_diffusion_model(args):
178
178
metrics = model .evaluate (generate_test , batch_size = args .batch_size , steps = args .test_steps , return_dict = True )
179
179
logging .info (f'Test metrics: { metrics } ' )
180
180
181
- data , labels , paths = big_batch_from_minibatch_generator (generate_test , args .test_steps )
182
- preds = model .plot_reconstructions ((data , labels ), prefix = f'{ args .output_folder } /{ args .id } /reconstructions/' )
181
+ steps = 1 if args .batch_size > 3 else args .test_steps
182
+ data , labels , paths = big_batch_from_minibatch_generator (generate_test , steps )
183
+ sides = int (args .batch_size )
184
+ preds = model .plot_reconstructions ((data , labels ), num_rows = sides , num_cols = sides ,
185
+ prefix = f'{ args .output_folder } /{ args .id } /reconstructions/' )
183
186
image_out = {args .tensor_maps_in [0 ].output_name (): data [args .tensor_maps_in [0 ].input_name ()]}
184
187
predictions_to_pngs (preds , args .tensor_maps_in , args .tensor_maps_in , data , image_out , paths , f'{ args .output_folder } /{ args .id } /reconstructions/' )
185
188
if model .tensor_map .axes () == 2 :
186
189
model .plot_ecgs (num_rows = 4 , prefix = os .path .dirname (checkpoint_path ))
187
190
else :
188
- model .plot_images (num_cols = min ( 4 , args . batch_size ), num_rows = min ( 4 , args . test_steps ) , prefix = os .path .dirname (checkpoint_path ))
191
+ model .plot_images (num_cols = sides , num_rows = sides , prefix = os .path .dirname (checkpoint_path ))
189
192
return model
190
193
191
194
0 commit comments