@@ -33,14 +33,12 @@ __global__ void trilinear_fw_kernel(
33
33
34
34
35
35
torch::Tensor trilinear_fw_cu (
36
- torch::Tensor feats,
37
- torch::Tensor points
36
+ const torch::Tensor feats,
37
+ const torch::Tensor points
38
38
){
39
39
const int N = feats.size (0 ), F = feats.size (2 );
40
40
41
41
torch::Tensor feat_interp = torch::zeros ({N, F}, feats.options ());
42
- torch::Tensor feat_interp2 = torch::zeros ({N, F}, feats.options ());
43
-
44
42
45
43
const dim3 threads (16 , 16 );
46
44
const dim3 blocks ((N+threads.x -1 )/threads.x , (F+threads.y -1 )/threads.y );
@@ -55,4 +53,63 @@ torch::Tensor trilinear_fw_cu(
55
53
}));
56
54
57
55
return feat_interp;
56
+ }
57
+
58
+
59
/**
 * Backward kernel for trilinear interpolation: scatters the upstream gradient
 * of each interpolated feature onto the 8 per-point input feature slots.
 *
 * Thread layout: the x grid/block dimension indexes the point `n`, the y
 * dimension indexes the feature channel `f`. Threads whose (n, f) fall outside
 * feats.size(0) x feats.size(2) exit without touching memory.
 *
 * @param dL_dfeat_interp  upstream gradient, indexed as [n][f]
 * @param feats            forward-pass features; only its sizes along
 *                         dimensions 0 and 2 are read here
 * @param points           per-point coordinates, indexed as [n][0..2]
 * @param dL_dfeats        output gradient, written at [n][0..7][f]
 */
template <typename scalar_t>
__global__ void trilinear_bw_kernel(
    const torch::PackedTensorAccessor<scalar_t, 3 - 1, torch::RestrictPtrTraits, size_t> dL_dfeat_interp,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> feats,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> points,
    torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> dL_dfeats
){
    const int n = blockIdx.x * blockDim.x + threadIdx.x;
    const int f = blockIdx.y * blockDim.y + threadIdx.y;

    // Only threads mapped to a valid (point, channel) pair do any work.
    const bool in_range = (n < feats.size(0)) && (f < feats.size(2));
    if (!in_range) return;

    // Coordinates are noted as lying in -1~1; (p + 1) / 2 remaps them to 0~1.
    const scalar_t u = (points[n][0] + 1) / 2;
    const scalar_t v = (points[n][1] + 1) / 2;
    const scalar_t w = (points[n][2] + 1) / 2;

    // Bilinear weights over (v, w); the fourth is the remainder so that the
    // four weights sum to one.
    const scalar_t w_a = (1 - v) * (1 - w);
    const scalar_t w_b = (1 - v) * w;
    const scalar_t w_c = v * (1 - w);
    const scalar_t w_d = 1 - w_a - w_b - w_c;

    // Upstream gradient for this (point, channel); read once and reused.
    const scalar_t grad = dL_dfeat_interp[n][f];
    const scalar_t u_lo = 1 - u;

    // Slots 0..3 are weighted by (1 - u), slots 4..7 by u.
    dL_dfeats[n][0][f] = u_lo * w_a * grad;
    dL_dfeats[n][1][f] = u_lo * w_b * grad;
    dL_dfeats[n][2][f] = u_lo * w_c * grad;
    dL_dfeats[n][3][f] = u_lo * w_d * grad;
    dL_dfeats[n][4][f] = u * w_a * grad;
    dL_dfeats[n][5][f] = u * w_b * grad;
    dL_dfeats[n][6][f] = u * w_c * grad;
    dL_dfeats[n][7][f] = u * w_d * grad;
}
90
+
91
+
92
/**
 * Host-side launcher for the backward pass of trilinear interpolation.
 *
 * Allocates a zero-initialised gradient tensor of shape {N, 8, F}, where
 * N = feats.size(0) and F = feats.size(2), and launches trilinear_bw_kernel
 * over a 2D grid of 16x16-thread blocks (x covers N, y covers F).
 *
 * @param dL_dfeat_interp  upstream gradient; must be 2-D and cover at least [N, F]
 * @param feats            forward-pass features; must be 3-D (defines N, F,
 *                         dtype and device of the result)
 * @param points           per-point coordinates; must be 2-D and cover at
 *                         least [N, 3]
 * @return                 gradient with respect to the features, shape {N, 8, F}
 *
 * Raises a c10::Error (via TORCH_CHECK) on malformed inputs or when the kernel
 * launch is rejected by the CUDA runtime.
 */
torch::Tensor trilinear_bw_cu(
    const torch::Tensor dL_dfeat_interp,
    const torch::Tensor feats,
    const torch::Tensor points
){
    TORCH_CHECK(feats.is_cuda() && points.is_cuda() && dL_dfeat_interp.is_cuda(),
                "trilinear_bw_cu: all inputs must be CUDA tensors");
    TORCH_CHECK(feats.dim() == 3,
                "trilinear_bw_cu: feats must be 3-D, got ", feats.dim(), "-D");

    // Keep sizes in 64-bit; Tensor::size() returns int64_t.
    const int64_t N = feats.size(0), F = feats.size(2);

    // The kernel reads points[n][0..2] and dL_dfeat_interp[n][f] for every
    // n < N, f < F; reject inputs that would make those reads out of bounds.
    TORCH_CHECK(points.dim() == 2 && points.size(0) >= N && points.size(1) >= 3,
                "trilinear_bw_cu: points must be 2-D and cover [", N, ", 3]");
    TORCH_CHECK(dL_dfeat_interp.dim() == 2 &&
                dL_dfeat_interp.size(0) >= N && dL_dfeat_interp.size(1) >= F,
                "trilinear_bw_cu: dL_dfeat_interp must be 2-D and cover [", N, ", ", F, "]");

    torch::Tensor dL_dfeats = torch::zeros({N, 8, F}, feats.options());

    // A zero-sized grid is an invalid launch configuration; nothing to compute.
    if (N == 0 || F == 0) return dL_dfeats;

    const dim3 threads(16, 16);
    const dim3 blocks(
        static_cast<unsigned int>((N + threads.x - 1) / threads.x),
        static_cast<unsigned int>((F + threads.y - 1) / threads.y));

    AT_DISPATCH_FLOATING_TYPES(feats.scalar_type(), "trilinear_bw_cu",
    ([&] {
        trilinear_bw_kernel<scalar_t><<<blocks, threads>>>(
            dL_dfeat_interp.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            feats.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
            points.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            dL_dfeats.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>()
        );
    }));

    // Kernel launches do not report errors directly; surface bad launch
    // configurations here instead of at some later, unrelated CUDA call.
    const cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "trilinear_bw_cu: kernel launch failed: ", cudaGetErrorString(launch_err));

    return dL_dfeats;
}
0 commit comments