Commit 4108dd9

tutorial 6
1 parent 5f7d4e8 commit 4108dd9

File tree

4 files changed, +127 -15 lines changed

include/utils.h
interpolation.cpp
interpolation_kernel.cu
test.py

include/utils.h  (+9 -2)

@@ -6,6 +6,13 @@
 
 
 torch::Tensor trilinear_fw_cu(
-    torch::Tensor feats,
-    torch::Tensor points
+    const torch::Tensor feats,
+    const torch::Tensor points
+);
+
+
+torch::Tensor trilinear_bw_cu(
+    const torch::Tensor dL_dfeat_interp,
+    const torch::Tensor feats,
+    const torch::Tensor points
 );

interpolation.cpp  (+18 -4)

@@ -1,9 +1,9 @@
 #include "utils.h"
 
 
-torch::Tensor trilinear_interpolation(
-    torch::Tensor feats,
-    torch::Tensor points
+torch::Tensor trilinear_interpolation_fw(
+    const torch::Tensor feats,
+    const torch::Tensor points
 ){
     CHECK_INPUT(feats);
     CHECK_INPUT(points);
@@ -12,6 +12,20 @@ torch::Tensor trilinear_interpolation(
 }
 
 
+torch::Tensor trilinear_interpolation_bw(
+    const torch::Tensor dL_dfeat_interp,
+    const torch::Tensor feats,
+    const torch::Tensor points
+){
+    CHECK_INPUT(dL_dfeat_interp);
+    CHECK_INPUT(feats);
+    CHECK_INPUT(points);
+
+    return trilinear_bw_cu(dL_dfeat_interp, feats, points);
+}
+
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
-    m.def("trilinear_interpolation", &trilinear_interpolation);
+    m.def("trilinear_interpolation_fw", &trilinear_interpolation_fw);
+    m.def("trilinear_interpolation_bw", &trilinear_interpolation_bw);
 }
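Both wrappers validate that the inputs are contiguous CUDA tensors before dispatching to the launchers declared in include/utils.h, and the two m.def lines expose them to Python under the cppcuda_tutorial module name imported by test.py. For context, a minimal setup.py sketch for compiling such an extension with torch.utils.cpp_extension (not part of this commit; the source paths and include directory are assumptions based on the file names in this diff):

from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

setup(
    name='cppcuda_tutorial',
    ext_modules=[
        # builds the C++ bindings and the CUDA kernels into one Python module
        CUDAExtension(
            name='cppcuda_tutorial',
            sources=['interpolation.cpp', 'interpolation_kernel.cu'],
            include_dirs=['include'],  # assumed location of utils.h
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)

After a pip install ., import cppcuda_tutorial gives access to trilinear_interpolation_fw and trilinear_interpolation_bw as called in test.py.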

interpolation_kernel.cu  (+61 -4)

@@ -33,14 +33,12 @@ __global__ void trilinear_fw_kernel(
 
 
 torch::Tensor trilinear_fw_cu(
-    torch::Tensor feats,
-    torch::Tensor points
+    const torch::Tensor feats,
+    const torch::Tensor points
 ){
     const int N = feats.size(0), F = feats.size(2);
 
     torch::Tensor feat_interp = torch::zeros({N, F}, feats.options());
-    torch::Tensor feat_interp2 = torch::zeros({N, F}, feats.options());
-
 
     const dim3 threads(16, 16);
     const dim3 blocks((N+threads.x-1)/threads.x, (F+threads.y-1)/threads.y);
@@ -55,4 +53,63 @@ torch::Tensor trilinear_fw_cu(
     }));
 
     return feat_interp;
+}
+
+
+template <typename scalar_t>
+__global__ void trilinear_bw_kernel(
+    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> dL_dfeat_interp,
+    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> feats,
+    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> points,
+    torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> dL_dfeats
+){
+    const int n = blockIdx.x * blockDim.x + threadIdx.x;
+    const int f = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (n>=feats.size(0) || f>=feats.size(2)) return;
+
+    // point -1~1
+    const scalar_t u = (points[n][0]+1)/2;
+    const scalar_t v = (points[n][1]+1)/2;
+    const scalar_t w = (points[n][2]+1)/2;
+
+    const scalar_t a = (1-v)*(1-w);
+    const scalar_t b = (1-v)*w;
+    const scalar_t c = v*(1-w);
+    const scalar_t d = 1-a-b-c;
+
+    dL_dfeats[n][0][f] = (1-u)*a*dL_dfeat_interp[n][f];
+    dL_dfeats[n][1][f] = (1-u)*b*dL_dfeat_interp[n][f];
+    dL_dfeats[n][2][f] = (1-u)*c*dL_dfeat_interp[n][f];
+    dL_dfeats[n][3][f] = (1-u)*d*dL_dfeat_interp[n][f];
+    dL_dfeats[n][4][f] = u*a*dL_dfeat_interp[n][f];
+    dL_dfeats[n][5][f] = u*b*dL_dfeat_interp[n][f];
+    dL_dfeats[n][6][f] = u*c*dL_dfeat_interp[n][f];
+    dL_dfeats[n][7][f] = u*d*dL_dfeat_interp[n][f];
+}
+
+
+torch::Tensor trilinear_bw_cu(
+    const torch::Tensor dL_dfeat_interp,
+    const torch::Tensor feats,
+    const torch::Tensor points
+){
+    const int N = feats.size(0), F = feats.size(2);
+
+    torch::Tensor dL_dfeats = torch::zeros({N, 8, F}, feats.options());
+
+    const dim3 threads(16, 16);
+    const dim3 blocks((N+threads.x-1)/threads.x, (F+threads.y-1)/threads.y);
+
+    AT_DISPATCH_FLOATING_TYPES(feats.type(), "trilinear_bw_cu",
+    ([&] {
+        trilinear_bw_kernel<scalar_t><<<blocks, threads>>>(
+            dL_dfeat_interp.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+            feats.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
+            points.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+            dL_dfeats.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>()
+        );
+    }));
+
+    return dL_dfeats;
 }
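The backward kernel mirrors the forward interpolation: the forward output is a weighted sum of the eight cube-corner features, feat_interp[n][f] = sum_k w_k * feats[n][k][f], with the weights w_k built from u, v, w as above, so the gradient with respect to each corner feature is just the corresponding weight times the incoming gradient, dL_dfeats[n][k][f] = w_k * dL_dfeat_interp[n][f], which is exactly what trilinear_bw_kernel writes out, one (n, f) pair per thread. A pure-PyTorch sketch of the same computation (a hypothetical helper, not part of the commit, useful only for cross-checking the kernel):

import torch

def trilinear_bw_py(dL_dfeat_interp, feats, points):
    # feats is only needed for its shape in the CUDA version; unused here
    # points lie in [-1, 1]; map to [0, 1] exactly as in trilinear_bw_kernel
    u = (points[:, 0:1] + 1) / 2   # (N, 1)
    v = (points[:, 1:2] + 1) / 2
    w = (points[:, 2:3] + 1) / 2

    a = (1 - v) * (1 - w)
    b = (1 - v) * w
    c = v * (1 - w)
    d = 1 - a - b - c

    # weights of the 8 corners, in the order the kernel writes dL_dfeats
    weights = torch.cat([(1 - u) * a, (1 - u) * b, (1 - u) * c, (1 - u) * d,
                         u * a, u * b, u * c, u * d], dim=1)        # (N, 8)

    # dL/dfeats[n, k, f] = w_k(n) * dL/dfeat_interp[n, f]
    return weights.unsqueeze(-1) * dL_dfeat_interp.unsqueeze(1)     # (N, 8, F)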

test.py  (+39 -5)

@@ -32,19 +32,53 @@ def trilinear_interpolation_py(feats, points):
     return feats_interp
 
 
+class Trilinear_interpolation_cuda(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, feats, points):
+        feat_interp = cppcuda_tutorial.trilinear_interpolation_fw(feats, points)
+
+        ctx.save_for_backward(feats, points)
+
+        return feat_interp
+
+    @staticmethod
+    def backward(ctx, dL_dfeat_interp):
+        feats, points = ctx.saved_tensors
+
+        dL_dfeats = cppcuda_tutorial.trilinear_interpolation_bw(dL_dfeat_interp.contiguous(), feats, points)
+
+        return dL_dfeats, None
+
+
 if __name__ == '__main__':
     N = 65536; F = 256
-    feats = torch.rand(N, 8, F, device='cuda').requires_grad_()
+    rand = torch.rand(N, 8, F, device='cuda')
+    feats = rand.clone().requires_grad_()
+    feats2 = rand.clone().requires_grad_()
     points = torch.rand(N, 3, device='cuda')*2-1
 
     t = time.time()
-    out_cuda = cppcuda_tutorial.trilinear_interpolation(feats, points)
+    out_cuda = Trilinear_interpolation_cuda.apply(feats2, points)
     torch.cuda.synchronize()
-    print(' cuda time', time.time()-t, 's')
+    print(' cuda fw time', time.time()-t, 's')
 
     t = time.time()
     out_py = trilinear_interpolation_py(feats, points)
     torch.cuda.synchronize()
-    print('pytorch time', time.time()-t, 's')
+    print('pytorch fw time', time.time()-t, 's')
+
+    print('fw all close', torch.allclose(out_py, out_cuda))
+
+    t = time.time()
+    loss2 = out_cuda.sum()
+    loss2.backward()
+    torch.cuda.synchronize()
+    print(' cuda bw time', time.time()-t, 's')
+
+    t = time.time()
+    loss = out_py.sum()
+    loss.backward()
+    torch.cuda.synchronize()
+    print('pytorch bw time', time.time()-t, 's')
 
-    print(torch.allclose(out_py, out_cuda))
+    print('bw all close', torch.allclose(feats.grad, feats2.grad))
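Besides comparing feats.grad against the pure-PyTorch reference, the custom autograd.Function can be spot-checked with torch.autograd.gradcheck on a tiny double-precision input. A sketch, not part of this commit; it relies on the Trilinear_interpolation_cuda wrapper defined above, and it works because AT_DISPATCH_FLOATING_TYPES also instantiates the kernels for float64:

# tiny float64 inputs, since gradcheck uses finite differences
feats64 = torch.rand(16, 8, 4, dtype=torch.float64, device='cuda', requires_grad=True)
points64 = torch.rand(16, 3, dtype=torch.float64, device='cuda')*2-1

# compares the analytical gradient from trilinear_interpolation_bw
# against numerical finite differences; prints True if they match
print(torch.autograd.gradcheck(Trilinear_interpolation_cuda.apply,
                               (feats64, points64)))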
