diff --git a/mvn/utils/multiview.py b/mvn/utils/multiview.py index 8a146709..5ec7612c 100644 --- a/mvn/utils/multiview.py +++ b/mvn/utils/multiview.py @@ -32,7 +32,7 @@ def update_after_crop(self, bbox): def update_after_resize(self, image_shape, new_image_shape): height, width = image_shape - new_width, new_height = new_image_shape + new_height, new_width = new_image_shape fx, fy, cx, cy = self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2] diff --git a/train.py b/train.py index 1c3c8605..932b492e 100644 --- a/train.py +++ b/train.py @@ -191,7 +191,7 @@ def one_epoch(model, criterion, opt, config, dataloader, device, epoch, n_iters_ keypoints_3d_pred, heatmaps_pred, volumes_pred, confidences_pred, cuboids_pred, coord_volumes_pred, base_points_pred = model(images_batch, proj_matricies_batch, batch) batch_size, n_views, image_shape = images_batch.shape[0], images_batch.shape[1], tuple(images_batch.shape[3:]) - n_joints = keypoints_3d_pred[0].shape[1] + n_joints = keypoints_3d_pred.shape[1] keypoints_3d_binary_validity_gt = (keypoints_3d_validity_gt > 0.0).type(torch.float32)