diff --git a/main.py b/main.py index a03f0d2..2211c58 100755 --- a/main.py +++ b/main.py @@ -377,10 +377,9 @@ def inference(args, epoch, data_loader, model, offset=0): statistics.append(loss_values) # import IPython; IPython.embed() if args.save_flow or args.render_validation: - for i in range(args.inference_batch_size): + for i in range(args.effective_inference_batch_size): _pflow = output[i].data.cpu().numpy().transpose(1, 2, 0) - flow_utils.writeFlow( join(flow_folder, '%06d.flo'%(batch_idx * args.inference_batch_size + i)), _pflow) - + flow_utils.writeFlow( join(flow_folder, '%06d.flo'%(batch_idx * args.effective_inference_batch_size + i)), _pflow) progress.set_description('Inference Averages for Epoch {}: '.format(epoch) + tools.format_dictionary_of_losses(loss_labels, np.array(statistics).mean(axis=0))) progress.update(1)