#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
| 3 | +REGISTER_OP("NnDistance") |
| 4 | + .Input("xyz1: float32") |
| 5 | + .Input("xyz2: float32") |
| 6 | + .Output("dist1: float32") |
| 7 | + .Output("idx1: int32") |
| 8 | + .Output("dist2: float32") |
| 9 | + .Output("idx2: int32"); |
| 10 | +REGISTER_OP("NnDistanceGrad") |
| 11 | + .Input("xyz1: float32") |
| 12 | + .Input("xyz2: float32") |
| 13 | + .Input("grad_dist1: float32") |
| 14 | + .Input("idx1: int32") |
| 15 | + .Input("grad_dist2: float32") |
| 16 | + .Input("idx2: int32") |
| 17 | + .Output("grad_xyz1: float32") |
| 18 | + .Output("grad_xyz2: float32"); |
| 19 | +using namespace tensorflow; |
| 20 | + |
// Brute-force nearest-neighbor search between two batched 3D point sets.
//
// For each of the n points per batch in `xyz1`, scans all m points of the same
// batch in `xyz2` and records the smallest squared Euclidean distance in
// dist[i*n+j] and the index of the winning candidate in idx[i*n+j].
// Both point sets are laid out row-major as (batch, count, 3) float arrays.
// Ties keep the lowest index; if m == 0 a row is written as distance 0, index 0.
static void nnsearch(int b, int n, int m, const float* xyz1, const float* xyz2,
                     float* dist, int* idx) {
  for (int batch = 0; batch < b; batch++) {
    const float* query_base = xyz1 + batch * n * 3;
    const float* ref_base = xyz2 + batch * m * 3;
    for (int q = 0; q < n; q++) {
      const float qx = query_base[q * 3 + 0];
      const float qy = query_base[q * 3 + 1];
      const float qz = query_base[q * 3 + 2];
      double best_d = 0;
      int best_k = 0;
      for (int k = 0; k < m; k++) {
        const float dx = ref_base[k * 3 + 0] - qx;
        const float dy = ref_base[k * 3 + 1] - qy;
        const float dz = ref_base[k * 3 + 2] - qz;
        const double d = dx * dx + dy * dy + dz * dz;
        if (k == 0 || d < best_d) {
          best_d = d;
          best_k = k;
        }
      }
      dist[batch * n + q] = best_d;
      idx[batch * n + q] = best_k;
    }
  }
}
| 44 | + |
| 45 | +class NnDistanceOp : public OpKernel{ |
| 46 | + public: |
| 47 | + explicit NnDistanceOp(OpKernelConstruction* context):OpKernel(context){} |
| 48 | + void Compute(OpKernelContext * context)override{ |
| 49 | + const Tensor& xyz1_tensor=context->input(0); |
| 50 | + const Tensor& xyz2_tensor=context->input(1); |
| 51 | + OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz1 be of shape (batch,#points,3)")); |
| 52 | + OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz1")); |
| 53 | + int b=xyz1_tensor.shape().dim_size(0); |
| 54 | + int n=xyz1_tensor.shape().dim_size(1); |
| 55 | + OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz2 be of shape (batch,#points,3)")); |
| 56 | + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz2")); |
| 57 | + int m=xyz2_tensor.shape().dim_size(1); |
| 58 | + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistance expects xyz1 and xyz2 have same batch size")); |
| 59 | + auto xyz1_flat=xyz1_tensor.flat<float>(); |
| 60 | + const float * xyz1=&xyz1_flat(0); |
| 61 | + auto xyz2_flat=xyz2_tensor.flat<float>(); |
| 62 | + const float * xyz2=&xyz2_flat(0); |
| 63 | + Tensor * dist1_tensor=NULL; |
| 64 | + Tensor * idx1_tensor=NULL; |
| 65 | + Tensor * dist2_tensor=NULL; |
| 66 | + Tensor * idx2_tensor=NULL; |
| 67 | + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n},&dist1_tensor)); |
| 68 | + OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,n},&idx1_tensor)); |
| 69 | + auto dist1_flat=dist1_tensor->flat<float>(); |
| 70 | + auto idx1_flat=idx1_tensor->flat<int>(); |
| 71 | + OP_REQUIRES_OK(context,context->allocate_output(2,TensorShape{b,m},&dist2_tensor)); |
| 72 | + OP_REQUIRES_OK(context,context->allocate_output(3,TensorShape{b,m},&idx2_tensor)); |
| 73 | + auto dist2_flat=dist2_tensor->flat<float>(); |
| 74 | + auto idx2_flat=idx2_tensor->flat<int>(); |
| 75 | + float * dist1=&(dist1_flat(0)); |
| 76 | + int * idx1=&(idx1_flat(0)); |
| 77 | + float * dist2=&(dist2_flat(0)); |
| 78 | + int * idx2=&(idx2_flat(0)); |
| 79 | + nnsearch(b,n,m,xyz1,xyz2,dist1,idx1); |
| 80 | + nnsearch(b,m,n,xyz2,xyz1,dist2,idx2); |
| 81 | + } |
| 82 | +}; |
| 83 | +REGISTER_KERNEL_BUILDER(Name("NnDistance").Device(DEVICE_CPU), NnDistanceOp); |
| 84 | +class NnDistanceGradOp : public OpKernel{ |
| 85 | + public: |
| 86 | + explicit NnDistanceGradOp(OpKernelConstruction* context):OpKernel(context){} |
| 87 | + void Compute(OpKernelContext * context)override{ |
| 88 | + const Tensor& xyz1_tensor=context->input(0); |
| 89 | + const Tensor& xyz2_tensor=context->input(1); |
| 90 | + const Tensor& grad_dist1_tensor=context->input(2); |
| 91 | + const Tensor& idx1_tensor=context->input(3); |
| 92 | + const Tensor& grad_dist2_tensor=context->input(4); |
| 93 | + const Tensor& idx2_tensor=context->input(5); |
| 94 | + OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz1 be of shape (batch,#points,3)")); |
| 95 | + OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz1")); |
| 96 | + int b=xyz1_tensor.shape().dim_size(0); |
| 97 | + int n=xyz1_tensor.shape().dim_size(1); |
| 98 | + OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz2 be of shape (batch,#points,3)")); |
| 99 | + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz2")); |
| 100 | + int m=xyz2_tensor.shape().dim_size(1); |
| 101 | + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistanceGrad expects xyz1 and xyz2 have same batch size")); |
| 102 | + OP_REQUIRES(context,grad_dist1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires grad_dist1 be of shape(batch,#points)")); |
| 103 | + OP_REQUIRES(context,idx1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires idx1 be of shape(batch,#points)")); |
| 104 | + OP_REQUIRES(context,grad_dist2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires grad_dist2 be of shape(batch,#points)")); |
| 105 | + OP_REQUIRES(context,idx2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires idx2 be of shape(batch,#points)")); |
| 106 | + auto xyz1_flat=xyz1_tensor.flat<float>(); |
| 107 | + const float * xyz1=&xyz1_flat(0); |
| 108 | + auto xyz2_flat=xyz2_tensor.flat<float>(); |
| 109 | + const float * xyz2=&xyz2_flat(0); |
| 110 | + auto idx1_flat=idx1_tensor.flat<int>(); |
| 111 | + const int * idx1=&idx1_flat(0); |
| 112 | + auto idx2_flat=idx2_tensor.flat<int>(); |
| 113 | + const int * idx2=&idx2_flat(0); |
| 114 | + auto grad_dist1_flat=grad_dist1_tensor.flat<float>(); |
| 115 | + const float * grad_dist1=&grad_dist1_flat(0); |
| 116 | + auto grad_dist2_flat=grad_dist2_tensor.flat<float>(); |
| 117 | + const float * grad_dist2=&grad_dist2_flat(0); |
| 118 | + Tensor * grad_xyz1_tensor=NULL; |
| 119 | + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad_xyz1_tensor)); |
| 120 | + Tensor * grad_xyz2_tensor=NULL; |
| 121 | + OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad_xyz2_tensor)); |
| 122 | + auto grad_xyz1_flat=grad_xyz1_tensor->flat<float>(); |
| 123 | + float * grad_xyz1=&grad_xyz1_flat(0); |
| 124 | + auto grad_xyz2_flat=grad_xyz2_tensor->flat<float>(); |
| 125 | + float * grad_xyz2=&grad_xyz2_flat(0); |
| 126 | + for (int i=0;i<b*n*3;i++) |
| 127 | + grad_xyz1[i]=0; |
| 128 | + for (int i=0;i<b*m*3;i++) |
| 129 | + grad_xyz2[i]=0; |
| 130 | + for (int i=0;i<b;i++){ |
| 131 | + for (int j=0;j<n;j++){ |
| 132 | + float x1=xyz1[(i*n+j)*3+0]; |
| 133 | + float y1=xyz1[(i*n+j)*3+1]; |
| 134 | + float z1=xyz1[(i*n+j)*3+2]; |
| 135 | + int j2=idx1[i*n+j]; |
| 136 | + float x2=xyz2[(i*m+j2)*3+0]; |
| 137 | + float y2=xyz2[(i*m+j2)*3+1]; |
| 138 | + float z2=xyz2[(i*m+j2)*3+2]; |
| 139 | + float g=grad_dist1[i*n+j]*2; |
| 140 | + grad_xyz1[(i*n+j)*3+0]+=g*(x1-x2); |
| 141 | + grad_xyz1[(i*n+j)*3+1]+=g*(y1-y2); |
| 142 | + grad_xyz1[(i*n+j)*3+2]+=g*(z1-z2); |
| 143 | + grad_xyz2[(i*m+j2)*3+0]-=(g*(x1-x2)); |
| 144 | + grad_xyz2[(i*m+j2)*3+1]-=(g*(y1-y2)); |
| 145 | + grad_xyz2[(i*m+j2)*3+2]-=(g*(z1-z2)); |
| 146 | + } |
| 147 | + for (int j=0;j<m;j++){ |
| 148 | + float x1=xyz2[(i*m+j)*3+0]; |
| 149 | + float y1=xyz2[(i*m+j)*3+1]; |
| 150 | + float z1=xyz2[(i*m+j)*3+2]; |
| 151 | + int j2=idx2[i*m+j]; |
| 152 | + float x2=xyz1[(i*n+j2)*3+0]; |
| 153 | + float y2=xyz1[(i*n+j2)*3+1]; |
| 154 | + float z2=xyz1[(i*n+j2)*3+2]; |
| 155 | + float g=grad_dist2[i*m+j]*2; |
| 156 | + grad_xyz2[(i*m+j)*3+0]+=g*(x1-x2); |
| 157 | + grad_xyz2[(i*m+j)*3+1]+=g*(y1-y2); |
| 158 | + grad_xyz2[(i*m+j)*3+2]+=g*(z1-z2); |
| 159 | + grad_xyz1[(i*n+j2)*3+0]-=(g*(x1-x2)); |
| 160 | + grad_xyz1[(i*n+j2)*3+1]-=(g*(y1-y2)); |
| 161 | + grad_xyz1[(i*n+j2)*3+2]-=(g*(z1-z2)); |
| 162 | + } |
| 163 | + } |
| 164 | + } |
| 165 | +}; |
| 166 | +REGISTER_KERNEL_BUILDER(Name("NnDistanceGrad").Device(DEVICE_CPU), NnDistanceGradOp); |
| 167 | + |
| 168 | +void NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i); |
| 169 | +class NnDistanceGpuOp : public OpKernel{ |
| 170 | + public: |
| 171 | + explicit NnDistanceGpuOp(OpKernelConstruction* context):OpKernel(context){} |
| 172 | + void Compute(OpKernelContext * context)override{ |
| 173 | + const Tensor& xyz1_tensor=context->input(0); |
| 174 | + const Tensor& xyz2_tensor=context->input(1); |
| 175 | + OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz1 be of shape (batch,#points,3)")); |
| 176 | + OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz1")); |
| 177 | + int b=xyz1_tensor.shape().dim_size(0); |
| 178 | + int n=xyz1_tensor.shape().dim_size(1); |
| 179 | + OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistance requires xyz2 be of shape (batch,#points,3)")); |
| 180 | + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistance only accepts 3d point set xyz2")); |
| 181 | + int m=xyz2_tensor.shape().dim_size(1); |
| 182 | + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistance expects xyz1 and xyz2 have same batch size")); |
| 183 | + auto xyz1_flat=xyz1_tensor.flat<float>(); |
| 184 | + const float * xyz1=&xyz1_flat(0); |
| 185 | + auto xyz2_flat=xyz2_tensor.flat<float>(); |
| 186 | + const float * xyz2=&xyz2_flat(0); |
| 187 | + Tensor * dist1_tensor=NULL; |
| 188 | + Tensor * idx1_tensor=NULL; |
| 189 | + Tensor * dist2_tensor=NULL; |
| 190 | + Tensor * idx2_tensor=NULL; |
| 191 | + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n},&dist1_tensor)); |
| 192 | + OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,n},&idx1_tensor)); |
| 193 | + auto dist1_flat=dist1_tensor->flat<float>(); |
| 194 | + auto idx1_flat=idx1_tensor->flat<int>(); |
| 195 | + OP_REQUIRES_OK(context,context->allocate_output(2,TensorShape{b,m},&dist2_tensor)); |
| 196 | + OP_REQUIRES_OK(context,context->allocate_output(3,TensorShape{b,m},&idx2_tensor)); |
| 197 | + auto dist2_flat=dist2_tensor->flat<float>(); |
| 198 | + auto idx2_flat=idx2_tensor->flat<int>(); |
| 199 | + float * dist1=&(dist1_flat(0)); |
| 200 | + int * idx1=&(idx1_flat(0)); |
| 201 | + float * dist2=&(dist2_flat(0)); |
| 202 | + int * idx2=&(idx2_flat(0)); |
| 203 | + NmDistanceKernelLauncher(b,n,xyz1,m,xyz2,dist1,idx1,dist2,idx2); |
| 204 | + } |
| 205 | +}; |
| 206 | +REGISTER_KERNEL_BUILDER(Name("NnDistance").Device(DEVICE_GPU), NnDistanceGpuOp); |
| 207 | + |
| 208 | +void NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2); |
| 209 | +class NnDistanceGradGpuOp : public OpKernel{ |
| 210 | + public: |
| 211 | + explicit NnDistanceGradGpuOp(OpKernelConstruction* context):OpKernel(context){} |
| 212 | + void Compute(OpKernelContext * context)override{ |
| 213 | + const Tensor& xyz1_tensor=context->input(0); |
| 214 | + const Tensor& xyz2_tensor=context->input(1); |
| 215 | + const Tensor& grad_dist1_tensor=context->input(2); |
| 216 | + const Tensor& idx1_tensor=context->input(3); |
| 217 | + const Tensor& grad_dist2_tensor=context->input(4); |
| 218 | + const Tensor& idx2_tensor=context->input(5); |
| 219 | + OP_REQUIRES(context,xyz1_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz1 be of shape (batch,#points,3)")); |
| 220 | + OP_REQUIRES(context,xyz1_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz1")); |
| 221 | + int b=xyz1_tensor.shape().dim_size(0); |
| 222 | + int n=xyz1_tensor.shape().dim_size(1); |
| 223 | + OP_REQUIRES(context,xyz2_tensor.dims()==3,errors::InvalidArgument("NnDistanceGrad requires xyz2 be of shape (batch,#points,3)")); |
| 224 | + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(2)==3,errors::InvalidArgument("NnDistanceGrad only accepts 3d point set xyz2")); |
| 225 | + int m=xyz2_tensor.shape().dim_size(1); |
| 226 | + OP_REQUIRES(context,xyz2_tensor.shape().dim_size(0)==b,errors::InvalidArgument("NnDistanceGrad expects xyz1 and xyz2 have same batch size")); |
| 227 | + OP_REQUIRES(context,grad_dist1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires grad_dist1 be of shape(batch,#points)")); |
| 228 | + OP_REQUIRES(context,idx1_tensor.shape()==(TensorShape{b,n}),errors::InvalidArgument("NnDistanceGrad requires idx1 be of shape(batch,#points)")); |
| 229 | + OP_REQUIRES(context,grad_dist2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires grad_dist2 be of shape(batch,#points)")); |
| 230 | + OP_REQUIRES(context,idx2_tensor.shape()==(TensorShape{b,m}),errors::InvalidArgument("NnDistanceGrad requires idx2 be of shape(batch,#points)")); |
| 231 | + auto xyz1_flat=xyz1_tensor.flat<float>(); |
| 232 | + const float * xyz1=&xyz1_flat(0); |
| 233 | + auto xyz2_flat=xyz2_tensor.flat<float>(); |
| 234 | + const float * xyz2=&xyz2_flat(0); |
| 235 | + auto idx1_flat=idx1_tensor.flat<int>(); |
| 236 | + const int * idx1=&idx1_flat(0); |
| 237 | + auto idx2_flat=idx2_tensor.flat<int>(); |
| 238 | + const int * idx2=&idx2_flat(0); |
| 239 | + auto grad_dist1_flat=grad_dist1_tensor.flat<float>(); |
| 240 | + const float * grad_dist1=&grad_dist1_flat(0); |
| 241 | + auto grad_dist2_flat=grad_dist2_tensor.flat<float>(); |
| 242 | + const float * grad_dist2=&grad_dist2_flat(0); |
| 243 | + Tensor * grad_xyz1_tensor=NULL; |
| 244 | + OP_REQUIRES_OK(context,context->allocate_output(0,TensorShape{b,n,3},&grad_xyz1_tensor)); |
| 245 | + Tensor * grad_xyz2_tensor=NULL; |
| 246 | + OP_REQUIRES_OK(context,context->allocate_output(1,TensorShape{b,m,3},&grad_xyz2_tensor)); |
| 247 | + auto grad_xyz1_flat=grad_xyz1_tensor->flat<float>(); |
| 248 | + float * grad_xyz1=&grad_xyz1_flat(0); |
| 249 | + auto grad_xyz2_flat=grad_xyz2_tensor->flat<float>(); |
| 250 | + float * grad_xyz2=&grad_xyz2_flat(0); |
| 251 | + NmDistanceGradKernelLauncher(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_dist2,idx2,grad_xyz1,grad_xyz2); |
| 252 | + } |
| 253 | +}; |
| 254 | +REGISTER_KERNEL_BUILDER(Name("NnDistanceGrad").Device(DEVICE_GPU), NnDistanceGradGpuOp); |