Skip to content

Commit 5f7d4e8

Browse files
committed
upd
1 parent f7d4831 commit 5f7d4e8

File tree

3 files changed

+53
-11
lines changed

3 files changed

+53
-11
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
.vscode/
12
# Byte-compiled / optimized / DLL files
23
__pycache__/
34
*.py[cod]

interpolation_kernel.cu

+10-7
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ __global__ void trilinear_fw_kernel(
2222
const scalar_t c = v*(1-w);
2323
const scalar_t d = 1-a-b-c;
2424
feat_interp[n][f] = (1-u)*(a*feats[n][0][f] +
25-
b*feats[n][1][f] +
26-
c*feats[n][2][f] +
27-
d*feats[n][3][f]) +
25+
b*feats[n][1][f] +
26+
c*feats[n][2][f] +
27+
d*feats[n][3][f]) +
2828
u*(a*feats[n][4][f] +
29-
b*feats[n][5][f] +
30-
c*feats[n][6][f] +
31-
d*feats[n][7][f]);
29+
b*feats[n][5][f] +
30+
c*feats[n][6][f] +
31+
d*feats[n][7][f]);
3232
}
3333

3434

@@ -39,6 +39,8 @@ torch::Tensor trilinear_fw_cu(
3939
const int N = feats.size(0), F = feats.size(2);
4040

4141
torch::Tensor feat_interp = torch::zeros({N, F}, feats.options());
42+
torch::Tensor feat_interp2 = torch::zeros({N, F}, feats.options());
43+
4244

4345
const dim3 threads(16, 16);
4446
const dim3 blocks((N+threads.x-1)/threads.x, (F+threads.y-1)/threads.y);
@@ -51,5 +53,6 @@ torch::Tensor trilinear_fw_cu(
5153
feat_interp.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>()
5254
);
5355
}));
54-
56+
57+
return feat_interp;
5558
}

test.py

+42-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,50 @@
11
import torch
22
import cppcuda_tutorial
3+
import time
4+
5+
6+
def trilinear_interpolation_py(feats, points):
    """
    Pure-PyTorch trilinear blend of the eight corner features of each cell.

    Inputs:
        feats: (N, 8, F) corner features
        points: (N, 3) local coordinates in [-1, 1]

    Outputs:
        feats_interp: (N, F)
    """
    # Rescale [-1, 1] -> [0, 1]. The 1-wide slices keep an (N, 1) shape so each
    # weight broadcasts over the F feature channels.
    unit = (points + 1) / 2
    u, v, w = unit[:, 0:1], unit[:, 1:2], unit[:, 2:3]

    # Four weights over the (v, w) plane; the last one is defined as the
    # remainder so that the four always sum to one.
    w00 = (1 - v) * (1 - w)
    w01 = (1 - v) * w
    w10 = v * (1 - w)
    w11 = 1 - w00 - w01 - w10

    def blend_face(first):
        # Weighted sum of the four consecutive corners starting at `first`.
        return (w00 * feats[:, first] +
                w01 * feats[:, first + 1] +
                w10 * feats[:, first + 2] +
                w11 * feats[:, first + 3])

    # Corners 0-3 are weighted by (1-u), corners 4-7 by u.
    return (1 - u) * blend_face(0) + u * blend_face(4)
333

434

535
if __name__ == '__main__':
    # Benchmark size: N cells, each with 8 corners of F feature channels.
    N = 65536
    F = 256
    feats = torch.rand(N, 8, F, device='cuda').requires_grad_()
    points = torch.rand(N, 3, device='cuda')*2-1

    # CUDA kernels are launched asynchronously: drain the input-generation
    # work first so it is not billed to the first measurement.
    torch.cuda.synchronize()
    t = time.perf_counter()
    out_cuda = cppcuda_tutorial.trilinear_interpolation(feats, points)
    # Wait for the kernel to finish before reading the clock.
    torch.cuda.synchronize()
    print(' cuda time', time.perf_counter()-t, 's')

    torch.cuda.synchronize()
    t = time.perf_counter()
    out_py = trilinear_interpolation_py(feats, points)
    torch.cuda.synchronize()
    print('pytorch time', time.perf_counter()-t, 's')

    # Sanity check: the extension and the reference implementation must agree.
    print(torch.allclose(out_py, out_cuda))

0 commit comments

Comments
 (0)