Skip to content

Commit e53393c

Browse files
committed
add 4th video content
1 parent 03dac60 commit e53393c

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

interpolation_kernel.cu

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,40 @@
11
#include <torch/extension.h>
22

33

4+
template <typename scalar_t>
5+
__global__ void trilinear_fw_kernel(
6+
const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> feats,
7+
const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> points,
8+
torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> feat_interp
9+
){
10+
const int n = blockIdx.x * blockDim.x + threadIdx.x;
11+
const int f = blockIdx.y * blockDim.y + threadIdx.y;
12+
13+
if (n>=feats.size(0) || f>=feats.size(2)) return;
14+
15+
// point -1~1
16+
const scalar_t u = (points[n][0]+1)/2;
17+
const scalar_t v = (points[n][1]+1)/2;
18+
const scalar_t w = (points[n][2]+1)/2;
19+
20+
const scalar_t a = (1-v)*(1-w);
21+
const scalar_t b = (1-v)*w;
22+
const scalar_t c = v*(1-w);
23+
const scalar_t d = 1-a-b-c;
24+
feat_interp[n][f] = (1-u)*(a*feats[n][0][f] +
25+
b*feats[n][1][f] +
26+
c*feats[n][2][f] +
27+
d*feats[n][3][f]) +
28+
u*(a*feats[n][4][f] +
29+
b*feats[n][5][f] +
30+
c*feats[n][6][f] +
31+
d*feats[n][7][f]);
32+
}
33+
34+
435
torch::Tensor trilinear_fw_cu(
536
torch::Tensor feats,
6-
torch::Tensor points
37+
torch::Tensor points,
738
){
839
const int N = feats.size(0), F = feats.size(2);
940

0 commit comments

Comments
 (0)