diff --git a/lib/knn/__init__.py b/lib/knn/__init__.py index ed16fe3b8..4e38c55f6 100755 --- a/lib/knn/__init__.py +++ b/lib/knn/__init__.py @@ -9,14 +9,15 @@ class KNearestNeighbor(Function): """ Compute k nearest neighbors for each query point. """ - def __init__(self, k): - self.k = k + # def __init__(self, k): + # self.k = k - def forward(self, ref, query): + @staticmethod + def forward(ctx, ref, query): ref = ref.float().cuda() query = query.float().cuda() - inds = torch.empty(query.shape[0], self.k, query.shape[2]).long().cuda() + inds = torch.empty(query.shape[0], 1, query.shape[2]).long().cuda() knn_pytorch.knn(ref, query, inds) diff --git a/lib/knn/src/knn.h b/lib/knn/src/knn.h index cc6d12efc..72918f2c8 100755 --- a/lib/knn/src/knn.h +++ b/lib/knn/src/knn.h @@ -35,7 +35,7 @@ int knn(at::Tensor& ref, at::Tensor& query, at::Tensor& idx) for (int b = 0; b < batch; b++) { knn_device(ref_dev + b * dim * ref_nb, ref_nb, query_dev + b * dim * query_nb, query_nb, dim, k, - dist_dev, idx_dev + b * k * query_nb, THCState_getCurrentStream(state)); + dist_dev, idx_dev + b * k * query_nb, at::cuda::getCurrentCUDAStream()); } THCudaFree(state, dist_dev); cudaError_t err = cudaGetLastError(); diff --git a/lib/loss.py b/lib/loss.py index 44a6d6fd0..4655a3ad7 100755 --- a/lib/loss.py +++ b/lib/loss.py @@ -10,7 +10,7 @@ def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list): - knn = KNearestNeighbor(1) + # knn = KNearestNeighbor(1) bs, num_p, _ = pred_c.size() pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1)) @@ -41,7 +41,7 @@ def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, if idx[0].item() in sym_list: target = target[0].transpose(1, 0).contiguous().view(3, -1) pred = pred.permute(2, 0, 1).contiguous().view(3, -1) - inds = knn(target.unsqueeze(0), pred.unsqueeze(0)) + inds = KNearestNeighbor.apply(target.unsqueeze(0), pred.unsqueeze(0)) target = torch.index_select(target, 1, inds.view(-1).detach() - 1) target = target.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous() pred = pred.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous() @@ -67,7 +67,7 @@ def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, new_target = torch.bmm((new_target - ori_t), ori_base).contiguous() # print('------------> ', dis[0][which_max[0]].item(), pred_c[0][which_max[0]].item(), idx[0].item()) - del knn + # del knn return loss, dis[0][which_max[0]], new_points.detach(), new_target.detach()